423 lines
29 KiB
HTML
423 lines
29 KiB
HTML
|
||
|
||
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" >
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
|
||
<title>Vector Addition — Triton documentation</title>
|
||
|
||
|
||
|
||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<!--[if lt IE 9]>
|
||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
|
||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||
<script src="../../_static/jquery.js"></script>
|
||
<script src="../../_static/underscore.js"></script>
|
||
<script src="../../_static/doctools.js"></script>
|
||
|
||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||
|
||
|
||
<link rel="index" title="Index" href="../../genindex.html" />
|
||
<link rel="search" title="Search" href="../../search.html" />
|
||
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
|
||
<link rel="prev" title="Tutorials" href="index.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
|
||
|
||
<div class="wy-grid-for-nav">
|
||
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../../index.html" class="icon icon-home"> Triton
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
|
||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||
<ul class="current">
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||
</ul>
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||
|
||
|
||
<nav class="wy-nav-top" aria-label="top navigation">
|
||
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../../index.html">Triton</a>
|
||
|
||
</nav>
|
||
|
||
|
||
<div class="wy-nav-content">
|
||
|
||
<div class="rst-content">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||
|
||
<ul class="wy-breadcrumbs">
|
||
|
||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||
|
||
<li><a href="index.html">Tutorials</a> »</li>
|
||
|
||
<li>Vector Addition</li>
|
||
|
||
|
||
<li class="wy-breadcrumbs-aside">
|
||
|
||
|
||
<a href="../../_sources/getting-started/tutorials/01-vector-add.rst.txt" rel="nofollow"> View page source</a>
|
||
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
|
||
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<div class="sphx-glr-download-link-note admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">here</span></a>
|
||
to download the full example code</p>
|
||
</div>
|
||
<div class="sphx-glr-example-title section" id="vector-addition">
|
||
<span id="sphx-glr-getting-started-tutorials-01-vector-add-py"></span><h1>Vector Addition<a class="headerlink" href="#vector-addition" title="Permalink to this headline">¶</a></h1>
|
||
<p>In this tutorial, you will write a simple vector addition using Triton and learn about:</p>
|
||
<ul class="simple">
|
||
<li><p>The basic programming model of Triton</p></li>
|
||
<li><p>The <cite>triton.jit</cite> decorator, which is used to define Triton kernels.</p></li>
|
||
<li><p>The best practices for validating and benchmarking your custom ops against native reference implementations</p></li>
|
||
</ul>
|
||
<div class="section" id="compute-kernel">
|
||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||
|
||
<span class="kn">import</span> <span class="nn">triton</span>
|
||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||
|
||
|
||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||
<span class="k">def</span> <span class="nf">add_kernel</span><span class="p">(</span>
|
||
<span class="n">x_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to first input vector</span>
|
||
<span class="n">y_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to second input vector</span>
|
||
<span class="n">output_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to output vector</span>
|
||
<span class="n">n_elements</span><span class="p">,</span> <span class="c1"># Size of the vector</span>
|
||
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="c1"># Number of elements each program should process</span>
|
||
<span class="c1"># NOTE: `constexpr` so it can be used as a shape value</span>
|
||
<span class="p">):</span>
|
||
<span class="c1"># There are multiple 'program's processing different data. We identify which program</span>
|
||
<span class="c1"># we are here</span>
|
||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># We use a 1D launch grid so axis is 0</span>
|
||
<span class="c1"># This program will process inputs that are offset from the initial data.</span>
|
||
<span class="c1"># for instance, if you had a vector of length 256 and block_size of 64, the programs</span>
|
||
<span class="c1"># would each access the elements [0:64, 64:128, 128:192, 192:256].</span>
|
||
<span class="c1"># Note that offsets is a list of pointers</span>
|
||
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
|
||
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
|
||
<span class="c1"># Create a mask to guard memory operations against out-of-bounds accesses</span>
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o"><</span> <span class="n">n_elements</span>
|
||
<span class="c1"># Load x and y from DRAM, masking out any extra elements in case the input is not a</span>
|
||
<span class="c1"># multiple of the block size</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">y_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||
<span class="c1"># Write x + y back to DRAM</span>
|
||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Let’s also declare a helper function to (1) allocate the <cite>z</cite> tensor
|
||
and (2) enqueue the above kernel with appropriate grid/block sizes.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="c1"># We need to preallocate the output</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">y</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">output</span><span class="o">.</span><span class="n">is_cuda</span>
|
||
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
|
||
<span class="c1"># The SPMD launch grid denotes the number of kernel instances that run in parallel.</span>
|
||
<span class="c1"># It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]</span>
|
||
<span class="c1"># In this case, we use a 1D grid where the size is the number of blocks</span>
|
||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE'</span><span class="p">]),)</span>
|
||
<span class="c1"># NOTE:</span>
|
||
<span class="c1"># - each torch.tensor object is implicitly converted into a pointer to its first element.</span>
|
||
<span class="c1"># - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel</span>
|
||
<span class="c1"># - don't forget to pass meta-parameters as keywords arguments</span>
|
||
<span class="n">add_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
|
||
<span class="c1"># We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still</span>
|
||
<span class="c1"># running asynchronously at this point.</span>
|
||
<span class="k">return</span> <span class="n">output</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We can now use the above function to compute the element-wise sum of two <cite>torch.tensor</cite> objects and test its correctness:</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">size</span> <span class="o">=</span> <span class="mi">98432</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||
<span class="n">output_torch</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||
<span class="n">output_triton</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">output_torch</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">output_triton</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s1">'The maximum difference between torch and triton is '</span>
|
||
<span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">output_torch</span> <span class="o">-</span> <span class="n">output_triton</span><span class="p">))</span><span class="si">}</span><span class="s1">'</span>
|
||
<span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p class="sphx-glr-script-out">Out:</p>
|
||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
|
||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
|
||
The maximum difference between torch and triton is 0.0
|
||
</pre></div>
|
||
</div>
|
||
<p>Seems like we’re good to go!</p>
|
||
</div>
|
||
<div class="section" id="benchmark">
|
||
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline">¶</a></h2>
|
||
<p>We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
|
||
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
|
||
for different problem sizes.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'size'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span>
|
||
<span class="mi">2</span> <span class="o">**</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="p">],</span> <span class="c1"># different possible values for `x_name`</span>
|
||
<span class="n">x_log</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># x axis is logarithmic</span>
|
||
<span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'torch'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg`</span>
|
||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'Triton'</span><span class="p">,</span> <span class="s1">'Torch'</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">)],</span> <span class="c1"># line styles</span>
|
||
<span class="n">ylabel</span><span class="o">=</span><span class="s1">'GB/s'</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
|
||
<span class="n">plot_name</span><span class="o">=</span><span class="s1">'vector-add-performance'</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
|
||
<span class="n">args</span><span class="o">=</span><span class="p">{},</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
|
||
<span class="p">)</span>
|
||
<span class="p">)</span>
|
||
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'torch'</span><span class="p">:</span>
|
||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton'</span><span class="p">:</span>
|
||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
|
||
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">12</span> <span class="o">*</span> <span class="n">size</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
|
||
<span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We can now run the decorated function above. Pass <cite>print_data=True</cite> to see the performance number, <cite>show_plots=True</cite> to plot them, and/or
|
||
<a href="#id1"><span class="problematic" id="id2">`</span></a>save_path=’/path/to/results/’ to save them to disk along with raw CSV data</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<img alt="01 vector add" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" />
|
||
<p class="sphx-glr-script-out">Out:</p>
|
||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>vector-add-performance:
|
||
size Triton Torch
|
||
0 4096.0 9.600000 9.600000
|
||
1 8192.0 15.999999 15.999999
|
||
2 16384.0 38.400001 38.400001
|
||
3 32768.0 76.800002 76.800002
|
||
4 65536.0 127.999995 127.999995
|
||
5 131072.0 219.428568 219.428568
|
||
6 262144.0 384.000001 384.000001
|
||
7 524288.0 472.615390 472.615390
|
||
8 1048576.0 614.400016 614.400016
|
||
9 2097152.0 722.823517 722.823517
|
||
10 4194304.0 780.190482 780.190482
|
||
11 8388608.0 812.429770 812.429770
|
||
12 16777216.0 833.084721 833.084721
|
||
13 33554432.0 842.004273 843.811163
|
||
14 67108864.0 847.448255 848.362445
|
||
15 134217728.0 849.737435 850.656574
|
||
</pre></div>
|
||
</div>
|
||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 43.794 seconds)</p>
|
||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
|
||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
|
||
</div>
|
||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
|
||
<p><a class="reference download internal" download="" href="../../_downloads/f191ee1e78dc52eb5f7cba88f71cef2f/01-vector-add.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">01-vector-add.ipynb</span></code></a></p>
|
||
</div>
|
||
</div>
|
||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
<footer>
|
||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||
<a href="index.html" class="btn btn-neutral float-left" title="Tutorials" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||
</div>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
<p>
|
||
© Copyright 2020, Philippe Tillet.
|
||
|
||
</p>
|
||
</div>
|
||
|
||
|
||
|
||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||
|
||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||
|
||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
|
||
</section>
|
||
|
||
</div>
|
||
|
||
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
|
||
<span class="rst-current-version" data-toggle="rst-current-version">
|
||
<span class="fa fa-book"> Other Versions</span>
|
||
v: master
|
||
<span class="fa fa-caret-down"></span>
|
||
</span>
|
||
<div class="rst-other-versions">
|
||
<dl>
|
||
<dt>Tags</dt>
|
||
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
|
||
</dl>
|
||
<dl>
|
||
<dt>Branches</dt>
|
||
<dd><a href="01-vector-add.html">master</a></dd>
|
||
</dl>
|
||
</div>
|
||
</div>
|
||
|
||
<script type="text/javascript">
|
||
jQuery(function () {
|
||
SphinxRtdTheme.Navigation.enable(true);
|
||
});
|
||
</script>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</body>
|
||
</html> |