{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Libdevice function\nTriton can invoke a custom function from an external library.\nIn this example, we will use the `libdevice` library to apply `asin` to a tensor.\nSee https://docs.nvidia.com/cuda/libdevice-users-guide/index.html for the semantics of all available libdevice functions.\n\nIn `triton/language/libdevice.py`, functions that perform the same computation on different data types are grouped together.\nFor example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.\nUsing Triton, you can simply call `tl.libdevice.asin`.\nTriton automatically selects the correct underlying device function to invoke based on the input and output types.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## asin Kernel\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef asin_kernel(\n    x_ptr,\n    y_ptr,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n):\n    # Each program instance handles one BLOCK_SIZE-wide slice of the input.\n    pid = tl.program_id(axis=0)\n    block_start = pid * BLOCK_SIZE\n    offsets = block_start + tl.arange(0, BLOCK_SIZE)\n    # Mask out offsets past the end of the tensor in the last block.\n    mask = offsets < n_elements\n    x = tl.load(x_ptr + offsets, mask=mask)\n    # Dispatches to the libdevice arc sine that matches the dtype of x.\n    x = tl.libdevice.asin(x)\n    tl.store(y_ptr + offsets, x, mask=mask)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using the default libdevice library path\nWe can use the default libdevice library path encoded in `triton/language/libdevice.py`.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, 
device='cuda')\noutput_triton = torch.zeros(size, device='cuda')\noutput_torch = torch.asin(x)\nassert x.is_cuda and output_triton.is_cuda\nn_elements = output_torch.numel()\n# Launch one program per BLOCK_SIZE-sized chunk of the input.\ngrid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)\nprint(output_torch)\nprint(output_triton)\nprint(\n    f'The maximum difference between torch and triton is '\n    f'{torch.max(torch.abs(output_torch - output_triton))}'\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Customize the libdevice library path\nWe can also override the default by passing the path to the `libdevice` library to the `asin` kernel through the `extern_libs` argument.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "output_triton = torch.empty_like(x)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,\n                  extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})\nprint(output_torch)\nprint(output_triton)\nprint(\n    f'The maximum difference between torch and triton is '\n    f'{torch.max(torch.abs(output_torch - output_triton))}'\n)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 0 }