{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Low-Memory Dropout\n\nIn this tutorial, you will write a memory-efficient implementation of dropout whose state\nwill be composed of a single int32 seed. This differs from more traditional implementations of dropout,\nwhose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:\n\n- The limitations of naive implementations of Dropout with PyTorch\n- Parallel pseudo-random number generation in Triton\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Baseline\nThe *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance\nof deep neural networks in low-data regime (i.e. regularization).\n\nIt takes a vector as input and produces a vector of the same shape as output. Each scalar in the\noutput has a probability $p$ of being changed to zero and otherwise it is copied from the input.\nThis forces the network to perform well even when only $1 - p$ scalars from the input are available.\n\nAt evaluation time we want to use the full power of the network so we set $p=0$. Naively this would\nincrease the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease\nin the output softmax temperature). To prevent this we multiply the output by $\\frac{1}{1 - p}$, which\nkeeps the norm consistent regardless of the dropout probability.\n\nLet's first take a look at the baseline implementation.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import tabulate\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n # The line below is the crucial part, described in the paragraph above!\n output = tl.where(x_keep, x / (1 - p), 0.0)\n # Write-back output\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n\n# Input tensor\nx = torch.randn(size=(10,)).cuda()\n# Dropout mask\np = 0.5\nx_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()\n#\noutput = dropout(x, x_keep=x_keep, p=p)\nprint(tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"keep mask\"] + x_keep.tolist(),\n [\"output\"] + output.tolist()\n]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Seeded dropout\nAbove implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly\nwe need to store the dropout mask for backpropagation. Secondly, dropout state management can get\nvery tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in\nhttps://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation\nthat (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management\nof persisting randomness across multiple invocations of the kernel.\n\nPseudorandom number generation in Triton is simple! In this tutorial we will use the\n:code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`\nvalues in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides\nother `random number generation strategies `.\n\n

Note

Triton's implementation of PRNG is based on the Philox algorithm (described on [SALMON2011]_).

\n\nLet's put it all together.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n BLOCK_SIZE: tl.constexpr,\n):\n # compute memory offsets of elements handled by this instance\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # load data from x\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n # randomly prune it\n random = tl.rand(seed, offsets)\n x_keep = random > p\n # write-back\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n\n\nx = torch.randn(size=(10,)).cuda()\n# Compare this to the baseline - dropout mask is never instantiated!\noutput = seeded_dropout(x, p=0.5, seed=123)\noutput2 = seeded_dropout(x, p=0.5, seed=123)\noutput3 = seeded_dropout(x, p=0.5, seed=512)\n\nprint(tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"output (seed = 123)\"] + output.tolist(),\n [\"output (seed = 123)\"] + output2.tolist(),\n [\"output (seed = 512)\"] + output3.tolist()\n]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Et Voil\u00e0! We have a triton kernel that applies the same dropout mask provided the seed is the same!\nIf you'd like explore further applications of pseudorandomness in GPU programming, we encourage you\nto explore the `triton/language/random` folder!\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercises\n1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.\n2. Add support for striding.\n3. (challenge) Implement a kernel for sparse Johnson-Lindenstrauss transform which generates the projection matrix one the fly each time using a seed.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n\n.. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, \"Parallel Random Numbers: As Easy as 1, 2, 3\", 2011\n.. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, \"Dropout: A Simple Way to Prevent Neural Networks from Overfitting\", JMLR 2014\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 0 }