576 lines
66 KiB
HTML
576 lines
66 KiB
HTML
|
||
|
||
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" >
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
|
||
<title>Matrix Multiplication — Triton documentation</title>
|
||
|
||
|
||
|
||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<!--[if lt IE 9]>
|
||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
|
||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||
<script src="../../_static/jquery.js"></script>
|
||
<script src="../../_static/underscore.js"></script>
|
||
<script src="../../_static/doctools.js"></script>
|
||
|
||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||
|
||
|
||
<link rel="index" title="Index" href="../../genindex.html" />
|
||
<link rel="search" title="Search" href="../../search.html" />
|
||
<link rel="next" title="triton" href="../../python-api/triton.html" />
|
||
<link rel="prev" title="Fused Softmax" href="02-fused-softmax.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
|
||
|
||
<div class="wy-grid-for-nav">
|
||
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../../index.html" class="icon icon-home"> Triton
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
|
||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||
<ul class="current">
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Matrix Multiplication</a><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="#motivations">Motivations</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#pointer-arithmetics">Pointer Arithmetics</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#l2-cache-optimizations">L2 Cache Optimizations</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#final-result">Final Result</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#square-matrix-performance">Square Matrix Performance</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||
</ul>
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||
|
||
|
||
<nav class="wy-nav-top" aria-label="top navigation">
|
||
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../../index.html">Triton</a>
|
||
|
||
</nav>
|
||
|
||
|
||
<div class="wy-nav-content">
|
||
|
||
<div class="rst-content">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||
|
||
<ul class="wy-breadcrumbs">
|
||
|
||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||
|
||
<li><a href="index.html">Tutorials</a> »</li>
|
||
|
||
<li>Matrix Multiplication</li>
|
||
|
||
|
||
<li class="wy-breadcrumbs-aside">
|
||
|
||
|
||
<a href="../../_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt" rel="nofollow"> View page source</a>
|
||
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
|
||
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<div class="sphx-glr-download-link-note admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p><a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Download the full example code</span></a></p>
|
||
</div>
|
||
<div class="sphx-glr-example-title section" id="matrix-multiplication">
|
||
<span id="sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"></span><h1>Matrix Multiplication<a class="headerlink" href="#matrix-multiplication" title="Permalink to this headline">¶</a></h1>
|
||
<p>In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
|
||
You will specifically learn about:</p>
|
||
<ul class="simple">
|
||
<li><p>Block-level matrix multiplications</p></li>
|
||
<li><p>Multi-dimensional pointer arithmetic</p></li>
|
||
<li><p>Program re-ordering for improved L2 cache hit rate</p></li>
|
||
<li><p>Automatic performance tuning</p></li>
|
||
</ul>
|
||
<div class="section" id="motivations">
|
||
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline">¶</a></h2>
|
||
<p>Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||
They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
|
||
Unfortunately, these libraries are often proprietary and cannot be easily customized to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||
In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.</p>
|
||
<p>Roughly speaking, the kernel that we will write will implement the following blocked algorithm:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># do in parallel</span>
|
||
<span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">):</span>
|
||
<span class="c1"># do in parallel</span>
|
||
<span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">):</span>
|
||
<span class="n">acc</span> <span class="o">=</span> <span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">):</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">]</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_N</span><span class="p">]</span>
|
||
<span class="n">acc</span> <span class="o">+=</span> <span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="n">C</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">acc</span><span class="p">;</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.</p>
|
||
</div>
|
||
<div class="section" id="compute-kernel">
|
||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||
<p>The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||
The main difficulty comes from the computation of the memory locations at which blocks of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need multi-dimensional pointer arithmetics.</p>
|
||
<div class="section" id="pointer-arithmetics">
|
||
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline">¶</a></h3>
|
||
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given by <code class="code docutils literal notranslate"><span class="pre">&X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*stride_x_0</span> <span class="pre">+</span> <span class="pre">j*stride_x_1</span></code>.
|
||
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+BLOCK_M,</span> <span class="pre">k:k+BLOCK_K]</span></code> and <code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+BLOCK_K,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+BLOCK_N]</span></code> can be defined in pseudo-code as:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&</span><span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span><span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_M</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
||
<span class="o">&</span><span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span><span class="n">n</span><span class="o">+</span><span class="n">BLOCK_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_N</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>This means that the pointers for blocks of A and B can be initialized (i.e., <code class="code docutils literal notranslate"><span class="pre">k=0</span></code>) in Triton as:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pid_m</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">pid_n</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_N</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||
<span class="n">rk</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">)</span>
|
||
<span class="c1"># pointer for A operand</span>
|
||
<span class="n">pa</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_a_0</span> <span class="o">+</span> <span class="n">rk</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_a_1</span><span class="p">);</span>
|
||
<span class="c1"># pointer for B operand</span>
|
||
<span class="n">pb</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_b_0</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_b_1</span><span class="p">);</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>And then updated in the inner loop as follows:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pa</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_a_1</span><span class="p">;</span>
|
||
<span class="n">pb</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_b_0</span><span class="p">;</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
</div>
|
||
<div class="section" id="l2-cache-optimizations">
|
||
<h3>L2 Cache Optimizations<a class="headerlink" href="#l2-cache-optimizations" title="Permalink to this headline">¶</a></h3>
|
||
<p>As mentioned above, each program instance computes an <code class="code docutils literal notranslate"><span class="pre">[BLOCK_M,</span> <span class="pre">BLOCK_N]</span></code> block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.
|
||
It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
|
||
And unfortunately, a simple row-major ordering</p>
|
||
<blockquote>
|
||
<div><div class="highlight-Python notranslate"><div class="highlight"><pre><span></span><span class="n">pid</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||
<span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">BLOCK_M</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_M</span><span class="p">;</span>
|
||
<span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">BLOCK_N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_N</span><span class="p">;</span>
|
||
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">grid_n</span><span class="p">;</span>
|
||
<span class="n">pid_n</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">%</span> <span class="n">grid_n</span><span class="p">;</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
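<p>is just not going to cut it: consecutive program instances then sweep across an entire row of blocks of <code class="code docutils literal notranslate"><span class="pre">C</span></code>, each one needing a different block column of <code class="code docutils literal notranslate"><span class="pre">B</span></code>, so all of <code class="code docutils literal notranslate"><span class="pre">B</span></code> is streamed through the L2 cache for every row of blocks and little of it is still cached when it is needed again.</p>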
|
||
<p>One possible solution is to launch blocks in an order that promotes data reuse.
|
||
This can be done by ‘super-grouping’ blocks in groups of <code class="code docutils literal notranslate"><span class="pre">GROUP_M</span></code> rows before switching to the next column:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pid</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||
<span class="n">width</span> <span class="o">=</span> <span class="n">GROUP_M</span> <span class="o">*</span> <span class="n">grid_n</span><span class="p">;</span>
|
||
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">width</span><span class="p">;</span>
|
||
<span class="c1"># we need to handle the case where M % (GROUP_M*BLOCK_M) != 0</span>
|
||
<span class="n">group_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">grid_m</span> <span class="o">-</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_M</span><span class="p">,</span> <span class="n">GROUP_M</span><span class="p">);</span>
|
||
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_M</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size</span><span class="p">);</span>
|
||
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">width</span><span class="p">)</span> <span class="o">//</span> <span class="p">(</span><span class="n">group_size</span><span class="p">);</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>In practice, this can improve the performance of our matrix multiplication kernel by more than 10% on some hardware architectures (e.g., from 220 to 245 TFLOPS on A100).</p>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="final-result">
|
||
<h2>Final Result<a class="headerlink" href="#final-result" title="Permalink to this headline">¶</a></h2>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||
<span class="kn">import</span> <span class="nn">triton</span>
|
||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||
|
||
<span class="c1"># %</span>
|
||
<span class="c1"># :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:</span>
|
||
<span class="c1"># - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try</span>
|
||
<span class="c1"># - An autotuning *key* whose change in values will trigger evaluation of all the provided configs</span>
|
||
|
||
<span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span>
|
||
<span class="n">configs</span><span class="o">=</span><span class="p">[</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>\
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="c1">#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),</span>
|
||
<span class="p">],</span>
|
||
<span class="n">key</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span>
|
||
<span class="p">)</span>
|
||
<span class="c1"># %</span>
|
||
<span class="c1"># We can now define our kernel as normal, using all the techniques presented above</span>
|
||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||
<span class="k">def</span> <span class="nf">_matmul</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">stride_am</span><span class="p">,</span> <span class="n">stride_ak</span><span class="p">,</span> <span class="n">stride_bk</span><span class="p">,</span> <span class="n">stride_bn</span><span class="p">,</span> <span class="n">stride_cm</span><span class="p">,</span> <span class="n">stride_cn</span><span class="p">,</span> <span class="o">**</span><span class="n">META</span><span class="p">):</span>
|
||
<span class="c1"># extract meta-parameters</span>
|
||
<span class="n">BLOCK_M</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_M'</span><span class="p">]</span>
|
||
<span class="n">BLOCK_N</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_N'</span><span class="p">]</span>
|
||
<span class="n">BLOCK_K</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_K'</span><span class="p">]</span>
|
||
<span class="n">GROUP_M</span> <span class="o">=</span> <span class="mi">8</span>
|
||
<span class="c1"># matrix multiplication</span>
|
||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">BLOCK_M</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_M</span>
|
||
<span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">BLOCK_N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_N</span>
|
||
<span class="c1"># re-order program ID for better L2 performance</span>
|
||
<span class="n">width</span> <span class="o">=</span> <span class="n">GROUP_M</span> <span class="o">*</span> <span class="n">grid_n</span>
|
||
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">width</span>
|
||
<span class="n">group_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">grid_m</span> <span class="o">-</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_M</span><span class="p">,</span> <span class="n">GROUP_M</span><span class="p">)</span>
|
||
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_M</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size</span><span class="p">)</span>
|
||
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">width</span><span class="p">)</span> <span class="o">//</span> <span class="p">(</span><span class="n">group_size</span><span class="p">)</span>
|
||
<span class="c1"># do matrix multiplication</span>
|
||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||
<span class="n">rk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">)</span>
|
||
<span class="n">A</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_am</span> <span class="o">+</span> <span class="n">rk</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_ak</span><span class="p">)</span>
|
||
<span class="n">B</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_bk</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_bn</span><span class="p">)</span>
|
||
<span class="n">acc</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="n">BLOCK_K</span><span class="p">):</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">A</span><span class="p">)</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
|
||
<span class="n">acc</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="n">A</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_ak</span>
|
||
<span class="n">B</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_bk</span>
|
||
<span class="c1"># triton can accept an arbitrary activation function</span>
|
||
<span class="c1"># via metaparameters!</span>
|
||
<span class="k">if</span> <span class="n">META</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">]:</span>
|
||
<span class="n">acc</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">](</span><span class="n">acc</span><span class="p">)</span>
|
||
<span class="c1"># rematerialize rm and rn to save registers</span>
|
||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||
<span class="n">C</span> <span class="o">=</span> <span class="n">C</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_cm</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_cn</span><span class="p">)</span>
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="n">acc</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||
|
||
|
||
<span class="c1"># we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`</span>
|
||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||
<span class="k">def</span> <span class="nf">leaky_relu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="mf">0.01</span><span class="o">*</span><span class="n">x</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We can now create a convenience wrapper function that only takes two input tensors
|
||
and (1) checks shape constraints; (2) allocates the output; (3) launches the above kernel.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="c1"># checks constraints</span>
|
||
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">"incompatible dimensions"</span>
|
||
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">"matrix A must be contiguous"</span>
|
||
<span class="k">assert</span> <span class="n">b</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">"matrix B must be contiguous"</span>
|
||
<span class="n">M</span><span class="p">,</span> <span class="n">K</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span>
|
||
<span class="n">_</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span>
|
||
<span class="c1"># allocates output</span>
|
||
<span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="c1"># launch kernel</span>
|
||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">META</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_M'</span><span class="p">])</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_N'</span><span class="p">]),</span> <span class="p">)</span>
|
||
<span class="n">pgm</span> <span class="o">=</span> <span class="n">_matmul</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
||
<span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> \
|
||
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>\
|
||
<span class="n">ACTIVATION</span> <span class="o">=</span> <span class="n">activation</span>
|
||
<span class="p">)</span>
|
||
<span class="c1"># done; return the output tensor</span>
|
||
<span class="k">return</span> <span class="n">c</span>
|
||
</pre></div>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="unit-test">
|
||
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
|
||
<p>We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="n">c_0</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">c_1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">c_0</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">c_1</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">c_0</span><span class="p">,</span> <span class="n">c_1</span><span class="p">))</span>
|
||
</pre></div>
|
||
</div>
|
||
<p class="sphx-glr-script-out">Out:</p>
|
||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
|
||
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
|
||
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
|
||
...,
|
||
[ -5.7031, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
|
||
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
|
||
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
|
||
device='cuda:0', dtype=torch.float16)
|
||
tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
|
||
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
|
||
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
|
||
...,
|
||
[ -5.7070, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
|
||
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
|
||
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
|
||
device='cuda:0', dtype=torch.float16)
|
||
tensor(True, device='cuda:0')
|
||
</pre></div>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="benchmark">
|
||
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline">¶</a></h2>
|
||
<div class="section" id="square-matrix-performance">
|
||
<h3>Square Matrix Performance<a class="headerlink" href="#square-matrix-performance" title="Permalink to this headline">¶</a></h3>
|
||
<p>We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to adapt this script to benchmark any other matrix shape.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">128</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">33</span><span class="p">)],</span> <span class="c1"># different possible values for `x_names`</span>
|
||
<span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'cublas'</span><span class="p">,</span> <span class="s1">'cublas + relu'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'triton + relu'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg`</span>
|
||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"cuBLAS"</span><span class="p">,</span> <span class="s2">"cuBLAS (+ torch.nn.LeakyReLU)"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">,</span> <span class="s2">"Triton (+ LeakyReLU)"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'--'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'--'</span><span class="p">)],</span> <span class="c1"># line styles</span>
|
||
<span class="n">ylabel</span><span class="o">=</span><span class="s2">"TFLOPS"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
|
||
<span class="n">plot_name</span><span class="o">=</span><span class="s2">"matmul-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
|
||
<span class="n">args</span><span class="o">=</span><span class="p">{}</span>
|
||
<span class="p">)</span>
|
||
<span class="p">)</span>
|
||
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'cublas'</span><span class="p">:</span>
|
||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton'</span><span class="p">:</span>
|
||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'cublas + relu'</span><span class="p">:</span>
|
||
<span class="n">torch_relu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch_relu</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)))</span>
|
||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton + relu'</span><span class="p">:</span>
|
||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">leaky_relu</span><span class="p">))</span>
|
||
<span class="n">perf</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">M</span> <span class="o">*</span> <span class="n">N</span> <span class="o">*</span> <span class="n">K</span> <span class="o">*</span> <span class="mf">1e-12</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">perf</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
|
||
|
||
|
||
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<img alt="Matmul performance in TFLOPS versus square matrix size for cuBLAS and Triton, each with and without an activation" class="sphx-glr-single-img" src="../../_images/sphx_glr_03-matrix-multiplication_001.png" />
|
||
<p class="sphx-glr-script-out">Out:</p>
|
||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
|
||
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
||
0 128.0 0.455111 ... 0.512000 0.512000
|
||
1 256.0 2.730667 ... 2.978909 2.978909
|
||
2 384.0 7.372800 ... 7.899428 7.899428
|
||
3 512.0 14.563555 ... 15.420235 15.420235
|
||
4 640.0 22.260869 ... 24.380953 24.380953
|
||
5 768.0 32.768000 ... 34.028308 34.028308
|
||
6 896.0 39.025776 ... 39.025776 37.971025
|
||
7 1024.0 49.932191 ... 52.428801 52.428801
|
||
8 1152.0 44.566925 ... 46.656000 46.656000
|
||
9 1280.0 51.200001 ... 56.109587 56.109587
|
||
10 1408.0 64.138541 ... 65.684049 65.684049
|
||
11 1536.0 80.430545 ... 76.106321 75.296679
|
||
12 1664.0 63.372618 ... 61.636381 61.636381
|
||
13 1792.0 72.983276 ... 68.953520 68.533074
|
||
14 1920.0 65.516586 ... 67.764707 69.120002
|
||
15 2048.0 73.584279 ... 75.234154 74.898285
|
||
16 2176.0 82.473969 ... 78.302130 78.608000
|
||
17 2304.0 68.446623 ... 73.051599 72.607513
|
||
18 2432.0 71.305746 ... 80.499895 80.963875
|
||
19 2560.0 78.019048 ... 76.740048 74.983980
|
||
20 2688.0 83.737433 ... 83.552988 80.537273
|
||
21 2816.0 78.868366 ... 77.193289 78.442822
|
||
22 2944.0 81.166173 ... 80.122235 77.385141
|
||
23 3072.0 81.472093 ... 83.391907 81.121923
|
||
24 3200.0 82.368085 ... 86.021503 87.671229
|
||
25 3328.0 83.516586 ... 86.320498 82.939284
|
||
26 3456.0 81.849303 ... 80.783132 85.043848
|
||
27 3584.0 87.211821 ... 92.126428 84.663603
|
||
28 3712.0 84.088676 ... 84.301560 81.550936
|
||
29 3840.0 81.079177 ... 84.164384 85.267542
|
||
30 3968.0 92.442373 ... 84.038524 83.807647
|
||
31 4096.0 93.401342 ... 91.304576 91.056800
|
||
|
||
[32 rows x 5 columns]
|
||
</pre></div>
|
||
</div>
|
||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> (2 minutes 10.315 seconds)</p>
|
||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
||
</div>
|
||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
|
||
<p><a class="reference download internal" download="" href="../../_downloads/b51b68bc1c6b1a5e509f67800b6235af/03-matrix-multiplication.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">03-matrix-multiplication.ipynb</span></code></a></p>
|
||
</div>
|
||
</div>
|
||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
<footer>
|
||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||
<a href="02-fused-softmax.html" class="btn btn-neutral float-left" title="Fused Softmax" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||
</div>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
<p>
|
||
© Copyright 2020, Philippe Tillet.
|
||
|
||
</p>
|
||
</div>
|
||
|
||
|
||
|
||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||
|
||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||
|
||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
|
||
</section>
|
||
|
||
</div>
|
||
|
||
|
||
<script>
  // Initialise the Read the Docs theme sidebar navigation.
  // jQuery(callback) is jQuery's shorthand for $(document).ready(callback),
  // so the callback runs only after the DOM has been parsed.
  // SphinxRtdTheme is expected to be provided by ../../_static/js/theme.js,
  // which is loaded in <head> after jquery.js.
  // NOTE(review): the `true` argument presumably enables sticky sidebar
  // navigation (sphinx_rtd_theme's `sticky_navigation` option); confirm
  // against theme.js.
  jQuery(function () {
    SphinxRtdTheme.Navigation.enable(true);
  });
</script>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</body>
|
||
</html> |