2022-07-14 07:22:19 +00:00
<!DOCTYPE html>
< html class = "writer-html5" lang = "en" >
< head >
< meta charset = "utf-8" / >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" / >
< title > Libdevice function — Triton documentation< / title >
<!-- Each stylesheet is included exactly once. Order matters for the cascade
     (later files win ties), so theme.css follows pygments.css, and custom.css
     comes last so project overrides apply on top of the theme and gallery styles. -->
<link rel="stylesheet" href="../../_static/pygments.css">
<link rel="stylesheet" href="../../_static/css/theme.css">
<link rel="stylesheet" href="../../_static/gallery.css">
<link rel="stylesheet" href="../../_static/gallery-binder.css">
<link rel="stylesheet" href="../../_static/gallery-dataframe.css">
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css">
<link rel="stylesheet" href="../../_static/css/custom.css">
<!-- [if lt IE 9]>
< script src = "../../_static/js/html5shiv.min.js" > < / script >
<![endif]-->
<!-- Included once: the element id "documentation_options" must be unique in the document. -->
<script id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
< script src = "../../_static/jquery.js" > < / script >
< script src = "../../_static/underscore.js" > < / script >
< script src = "../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../_static/js/theme.js" > < / script >
< link rel = "index" title = "Index" href = "../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../search.html" / >
< link rel = "next" title = "triton" href = "../../python-api/triton.html" / >
< link rel = "prev" title = "Fused Attention" href = "06-fused-attention.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../index.html" class = "icon icon-home" > Triton
< / a >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< p class = "caption" role = "heading" > < span class = "caption-text" > Getting Started< / span > < / p >
< ul class = "current" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../installation.html" > Installation< / a > < / li >
< li class = "toctree-l1 current" > < a class = "reference internal" href = "index.html" > Tutorials< / a > < ul class = "current" >
< li class = "toctree-l2" > < a class = "reference internal" href = "01-vector-add.html" > Vector Addition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "02-fused-softmax.html" > Fused Softmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "03-matrix-multiplication.html" > Matrix Multiplication< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "04-low-memory-dropout.html" > Low-Memory Dropout< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "05-layer-norm.html" > Layer Normalization< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "06-fused-attention.html" > Fused Attention< / a > < / li >
< li class = "toctree-l2 current" > < a class = "current reference internal" href = "#" > Libdevice function< / a > < ul >
< li class = "toctree-l3" > < a class = "reference internal" href = "#asin-kernel" > asin Kernel< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#using-the-default-libdevice-library-path" > Using the default libdevice library path< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#customize-the-libdevice-library-path" > Customize the libdevice library path< / a > < / li >
< / ul >
< / li >
< / ul >
< / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Python API< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.html" > triton< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.language.html" > triton.language< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.testing.html" > triton.testing< / a > < / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Programming Guide< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-1/introduction.html" > Introduction< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-2/related-work.html" > Related Work< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../index.html" > Triton< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../../index.html" class = "icon icon-home" > < / a > » < / li >
< li > < a href = "index.html" > Tutorials< / a > » < / li >
< li > Libdevice function< / li >
< li class = "wy-breadcrumbs-aside" >
< a href = "../../_sources/getting-started/tutorials/07-libdevice-function.rst.txt" rel = "nofollow" > View page source< / a >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< div class = "sphx-glr-download-link-note admonition note" >
< p class = "admonition-title" > Note< / p >
< p > Click < a class = "reference internal" href = "#sphx-glr-download-getting-started-tutorials-07-libdevice-function-py" > < span class = "std std-ref" > here< / span > < / a >
to download the full example code< / p >
< / div >
< div class = "sphx-glr-example-title section" id = "libdevice-function" >
< span id = "sphx-glr-getting-started-tutorials-07-libdevice-function-py" > < / span > < h1 > Libdevice function< a class = "headerlink" href = "#libdevice-function" title = "Permalink to this headline" > ¶< / a > < / h1 >
< p > Triton can invoke a custom function from an external library.
In this example, we will use the < cite > libdevice< / cite > library to apply < cite > asin< / cite > on a tensor.
Please refer to < a class = "reference external" href = "https://docs.nvidia.com/cuda/libdevice-users-guide/index.html" > https://docs.nvidia.com/cuda/libdevice-users-guide/index.html< / a > regarding the semantics of all available libdevice functions.< / p >
<p>In <cite>triton/language/libdevice.py</cite>, functions that perform the same computation on different data types are aggregated under a single name.
For example, both <cite>__nv_asin</cite> and <cite>__nv_asinf</cite> calculate the principal value of the arc sine of the input, but <cite>__nv_asin</cite> operates on <cite>double</cite> and <cite>__nv_asinf</cite> operates on <cite>float</cite>.
Using Triton, you can simply call <cite>tl.libdevice.asin</cite>:
Triton automatically selects the correct underlying device function to invoke based on the input and output types.</p>
< div class = "section" id = "asin-kernel" >
< h2 > asin Kernel< a class = "headerlink" href = "#asin-kernel" title = "Permalink to this headline" > ¶< / a > < / h2 >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton.language< / span > < span class = "k" > as< / span > < span class = "nn" > tl< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > asin_kernel< / span > < span class = "p" > (< / span >
< span class = "n" > x_ptr< / span > < span class = "p" > ,< / span >
< span class = "n" > y_ptr< / span > < span class = "p" > ,< / span >
< span class = "n" > n_elements< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "n" > pid< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > block_start< / span > < span class = "o" > =< / span > < span class = "n" > pid< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_SIZE< / span >
< span class = "n" > offsets< / span > < span class = "o" > =< / span > < span class = "n" > block_start< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > offsets< / span > < span class = "o" > < < / span > < span class = "n" > n_elements< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > x_ptr< / span > < span class = "o" > +< / span > < span class = "n" > offsets< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > libdevice< / span > < span class = "o" > .< / span > < span class = "n" > asin< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > y_ptr< / span > < span class = "o" > +< / span > < span class = "n" > offsets< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
< / div >
< div class = "section" id = "using-the-default-libdevice-library-path" >
< h2 > Using the default libdevice library path< a class = "headerlink" href = "#using-the-default-libdevice-library-path" title = "Permalink to this headline" > ¶< / a > < / h2 >
<p>We can use the default libdevice library path encoded in <cite>triton/language/libdevice.py</cite>.</p>
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > manual_seed< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > size< / span > < span class = "o" > =< / span > < span class = "mi" > 98432< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > size< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > )< / span >
< span class = "n" > output_triton< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > (< / span > < span class = "n" > size< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > )< / span >
< span class = "n" > output_torch< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > asin< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "k" > assert< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > is_cuda< / span > < span class = "ow" > and< / span > < span class = "n" > output_triton< / span > < span class = "o" > .< / span > < span class = "n" > is_cuda< / span >
< span class = "n" > n_elements< / span > < span class = "o" > =< / span > < span class = "n" > output_torch< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > meta< / span > < span class = "p" > :< / span > < span class = "p" > (< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > cdiv< / span > < span class = "p" > (< / span > < span class = "n" > n_elements< / span > < span class = "p" > ,< / span > < span class = "n" > meta< / span > < span class = "p" > [< / span > < span class = "s1" > ' BLOCK_SIZE' < / span > < span class = "p" > ]),)< / span >
< span class = "n" > asin_kernel< / span > < span class = "p" > [< / span > < span class = "n" > grid< / span > < span class = "p" > ](< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > output_triton< / span > < span class = "p" > ,< / span > < span class = "n" > n_elements< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "mi" > 1024< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > output_torch< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > output_triton< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span >
< span class = "sa" > f< / span > < span class = "s1" > ' The maximum difference between torch and triton is ' < / span >
< span class = "sa" > f< / span > < span class = "s1" > ' < / span > < span class = "si" > {< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > max< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > abs< / span > < span class = "p" > (< / span > < span class = "n" > output_torch< / span > < span class = "o" > -< / span > < span class = "n" > output_triton< / span > < span class = "p" > ))< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span >
< span class = "p" > )< / span >
< / pre > < / div >
< / div >
< p class = "sphx-glr-script-out" > Out:< / p >
< div class = "sphx-glr-script-out highlight-none notranslate" > < div class = "highlight" > < pre > < span > < / span > tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device=' cuda:0' )
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device=' cuda:0' )
The maximum difference between torch and triton is 2.384185791015625e-07
< / pre > < / div >
< / div >
< / div >
< div class = "section" id = "customize-the-libdevice-library-path" >
< h2 > Customize the libdevice library path< a class = "headerlink" href = "#customize-the-libdevice-library-path" title = "Permalink to this headline" > ¶< / a > < / h2 >
<p>We can also customize the libdevice library path by passing it to the <cite>asin</cite> kernel through the <cite>extern_libs</cite> argument, a dictionary that maps the library name (<cite>libdevice</cite>) to the path of its <cite>.bc</cite> file.</p>
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > output_triton< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "n" > asin_kernel< / span > < span class = "p" > [< / span > < span class = "n" > grid< / span > < span class = "p" > ](< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > output_triton< / span > < span class = "p" > ,< / span > < span class = "n" > n_elements< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "mi" > 1024< / span > < span class = "p" > ,< / span >
< span class = "n" > extern_libs< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s1" > ' libdevice' < / span > < span class = "p" > :< / span > < span class = "s1" > ' /usr/local/cuda/nvvm/libdevice/libdevice.10.bc' < / span > < span class = "p" > })< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > output_torch< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > output_triton< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span >
< span class = "sa" > f< / span > < span class = "s1" > ' The maximum difference between torch and triton is ' < / span >
< span class = "sa" > f< / span > < span class = "s1" > ' < / span > < span class = "si" > {< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > max< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > abs< / span > < span class = "p" > (< / span > < span class = "n" > output_torch< / span > < span class = "o" > -< / span > < span class = "n" > output_triton< / span > < span class = "p" > ))< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span >
< span class = "p" > )< / span >
< / pre > < / div >
< / div >
< p class = "sphx-glr-script-out" > Out:< / p >
< div class = "sphx-glr-script-out highlight-none notranslate" > < div class = "highlight" > < pre > < span > < / span > tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device=' cuda:0' )
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device=' cuda:0' )
The maximum difference between torch and triton is 2.384185791015625e-07
< / pre > < / div >
< / div >
2022-07-17 00:49:40 +00:00
< p class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong > ( 0 minutes 0.010 seconds)< / p >
2022-07-14 07:22:19 +00:00
< div class = "sphx-glr-footer class sphx-glr-footer-example docutils container" id = "sphx-glr-download-getting-started-tutorials-07-libdevice-function-py" >
< div class = "sphx-glr-download sphx-glr-download-python docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/3ff29f967ace7985da24aab10352fc76/07-libdevice-function.py" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Python< / span > < span class = "pre" > source< / span > < span class = "pre" > code:< / span > < span class = "pre" > 07-libdevice-function.py< / span > < / code > < / a > < / p >
< / div >
< div class = "sphx-glr-download sphx-glr-download-jupyter docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/1bc2e471d2fb0ec017c4d1d0890db4e2/07-libdevice-function.ipynb" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Jupyter< / span > < span class = "pre" > notebook:< / span > < span class = "pre" > 07-libdevice-function.ipynb< / span > < / code > < / a > < / p >
< / div >
< / div >
< p class = "sphx-glr-signature" > < a class = "reference external" href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p >
< / div >
< / div >
< / div >
< / div >
< footer >
< div class = "rst-footer-buttons" role = "navigation" aria-label = "footer navigation" >
< a href = "../../python-api/triton.html" class = "btn btn-neutral float-right" title = "triton" accesskey = "n" rel = "next" > Next < span class = "fa fa-arrow-circle-right" aria-hidden = "true" > < / span > < / a >
< a href = "06-fused-attention.html" class = "btn btn-neutral float-left" title = "Fused Attention" accesskey = "p" rel = "prev" > < span class = "fa fa-arrow-circle-left" aria-hidden = "true" > < / span > Previous< / a >
< / div >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2020, Philippe Tillet.
< / p >
< / div >
Built with < a href = "https://www.sphinx-doc.org/" > Sphinx< / a > using a
< a href = "https://github.com/readthedocs/sphinx_rtd_theme" > theme< / a >
provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< div class = "rst-versions" data-toggle = "rst-versions" role = "note" aria-label = "versions" >
< span class = "rst-current-version" data-toggle = "rst-current-version" >
< span class = "fa fa-book" > Other Versions< / span >
v: master
< span class = "fa fa-caret-down" > < / span >
< / span >
< div class = "rst-other-versions" >
< dl >
< dt > Tags< / dt >
< dd > < a href = "../../../v1.1.2/index.html" > v1.1.2< / a > < / dd >
< / dl >
< dl >
< dt > Branches< / dt >
< dd > < a href = "07-libdevice-function.html" > master< / a > < / dd >
< / dl >
< / div >
< / div >
< script type = "text/javascript" >
// When the DOM has finished loading, turn on the theme's sidebar navigation
// behaviour. SphinxRtdTheme is provided by _static/js/theme.js (included in <head>).
// NOTE(review): the meaning of the `true` flag is defined by that theme script;
// confirm there before changing it.
jQuery(document).ready(function () {
  SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >