[LANG] Added seeded random number generation - philox (#261)

This commit is contained in:
Szymon Sidor
2021-09-02 22:02:40 -07:00
committed by GitHub
parent c069ef907e
commit 8bedcce9be
16 changed files with 595 additions and 6 deletions

View File

@@ -66,7 +66,7 @@ def setup(app):
import sys
import os
sys.path.insert(0, os.path.abspath('../python/'))
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
autosummary_generate = True
# Sphinx gallery
@@ -78,6 +78,9 @@ sphinx_gallery_conf = {
'filename_pattern': '',
'ignore_pattern': r'__init__\.py',
'within_subsection_order': FileNameSortKey,
'reference_url': {
'sphinx_gallery': None,
}
}
# Add any paths that contain templates here, relative to this directory.

Binary image file added (41 KiB); contents not shown.

View File

@@ -121,6 +121,19 @@ Comparison ops
minimum
maximum
.. _Random Number Generation:
Random Number Generation
-------------------------
.. autosummary::
:toctree: generated
:nosignatures:
randint4x
randint
rand
randn
Compiler Hint Ops
-------------------
@@ -129,4 +142,4 @@ Compiler Hint Ops
:toctree: generated
:nosignatures:
multiple_of

View File

@@ -126,7 +126,7 @@ setup(
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
install_requires=["torch"],
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
include_package_data=True,

View File

@@ -0,0 +1,198 @@
import torch
import triton
import triton.language as tl
import pytest
import scipy.stats
import numpy as np
from numpy.random import Philox
#####################################
## Reference Philox Implementation
#####################################
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
self.DTYPE = DTYPE
# This config is better suited for the GPU
PHILOX_32 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B9,
PHILOX_KEY_B=0xBB67AE85,
PHILOX_ROUND_A=0xD2511F53,
PHILOX_ROUND_B=0xCD9E8D57,
DTYPE=np.uint32,
)
# This is what numpy implements
PHILOX_64 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B97F4A7C15,
PHILOX_KEY_B=0xBB67AE8584CAA73B,
PHILOX_ROUND_A=0xD2E7470EE14C6C93,
PHILOX_ROUND_B=0xCA5A826395121157,
DTYPE=np.uint64,
)
class CustomPhilox4x:
def __init__(self, seed, config):
self._config = config
seed = self._into_pieces(seed)
self._key = np.array(seed[:2], dtype=self._dtype)
self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)
@property
def _dtype(self):
return self._config.DTYPE
def _into_pieces(self, n, pad=4):
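# split the integer seed into `pad` words of the configured dtype, least-significant word first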
res = []
while len(res) < pad:
res.append(np.array(n, dtype=self._dtype))
n >>= (np.dtype(self._dtype).itemsize * 8)
assert n == 0
return tuple(res)
def _multiply_low_high(self, a, b):
low = a * b
high = int(a) * int(b)
high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
return low, high
def _single_round(self, counter, key):
lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
ret0 = hi1 ^ counter[1] ^ key[0]
ret1 = lo1
ret2 = hi0 ^ counter[3] ^ key[1]
ret3 = lo0
return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
def _raise_key(self, key):
ret0 = key[0] + self._config.PHILOX_KEY_A
ret1 = key[1] + self._config.PHILOX_KEY_B
return np.array([ret0, ret1], dtype=self._dtype)
def random_raw(self):
counter = self._counter
key = self._key
for _ in range(10):
counter = self._single_round(counter, key)
key = self._raise_key(key)
self.advance(1)
return counter
def advance(self, n_steps):
self._counter[0] += n_steps
assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"
class CustomPhilox(CustomPhilox4x):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.buffer = []
def random_raw(self):
if len(self.buffer) == 0:
self.buffer = list(super().random_raw())[::-1]
return int(self.buffer.pop())
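# The 64-bit config above mirrors numpy's Philox4x64-10. A minimal sanity check one could add
# (an illustrative sketch, not one of the original tests; it assumes numpy accepts an integer
# `key` and that `random_raw` yields outputs in the same order as CustomPhilox):
def _crosscheck_against_numpy(seed=42, n=16):
    ref = CustomPhilox(seed, config=PHILOX_64)
    ours = [ref.random_raw() for _ in range(n)]
    theirs = [int(v) for v in Philox(key=seed).random_raw(n)]
    return ours == theirs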
#####################################
## Unit Tests
#####################################
BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']\
for seed in [0, 42, 124, 54]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randint(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.int32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
# reference result
gen = CustomPhilox4x(seed, config=PHILOX_32)
out_ref = [gen.random_raw()[0] for _ in out_tri]
assert out_tri == out_ref
# test conversion of random uint32 into random float in [0, 1)
def test_uint32_to_uniform_float():
@triton.jit
def kernel(SRC, TGT, N, **meta):
pid = tl.program_id(0)
offset = pid * BLOCK + tl.arange(0, BLOCK)
src = tl.load(SRC + offset)
tgt = tl.random.uint32_to_uniform_float(src)
tl.store(TGT + offset, tgt, mask=offset < N)
def run(source):
target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
N = source.numel()
grid = lambda meta: (triton.cdiv(N, BLOCK),)
kernel[grid](source, target, N)
return target
# check range of edge values
n = 100
source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
target = run(source).tolist()
assert target == sorted(target)
assert all(0.0 <= num < 1.0 for num in target)
# check distribution is uniform
source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
target = run(source).tolist()
assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
for seed in [0, 42, 124, 54]]
)
def test_rand(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.rand(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
for seed in [0, 42, 124, 54]]
)
def test_randn(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randn(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert abs(x.mean()) < 1e-2
assert abs(x.std() - 1) < 1e-2

View File

@@ -9,4 +9,4 @@ from . import code_gen
from . import testing
from . import ops
# version
__version__ = '1.0.0'

View File

@@ -1,2 +1,4 @@
from . import core
from . import random
from .core import *
from .random import *

View File

@@ -648,6 +648,7 @@ def cdiv(x, div):
"""
return (x + div - 1) // div
@triton.jit
def minimum(x, y):
"""

View File

@@ -0,0 +1,208 @@
import triton
import triton.language as tl
# Notes
# 1. Triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's-complement operations are equivalent to unsigned operations.
# 2. multiply_low_high is currently inefficient.
# 3. Even though Philox sampling technically outputs ints, in many places we pretend they are actually uints, e.g. in uint32_to_uniform_float.
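# A quick illustration of note 1 (a sketch for the reader, not used by the kernels below):
# int32 two's-complement arithmetic agrees with uint32 arithmetic on the low 32 bits.
def _twos_complement_demo(a=0xD2511F53, b=0x12345678):
    mask = 0xFFFFFFFF
    to_int32 = lambda v: v - (1 << 32) if v & 0x80000000 else v
    uint_product = (a * b) & mask                     # uint32 wrap-around product
    int_product = (to_int32(a) * to_int32(b)) & mask  # same low 32 bits via int32 semantics
    return uint_product == int_product                # True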
@triton.jit
def PHILOX_KEY_A():
# 0x9E3779B9
return -1640531527
@triton.jit
def PHILOX_KEY_B():
# 0xBB67AE85
return -1150833019
@triton.jit
def PHILOX_ROUND_A():
# 0xD2511F53
return -766435501
@triton.jit
def PHILOX_ROUND_B():
# 0xCD9E8D57
return -845247145
@triton.jit
def hacky_to_uint64(x):
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
@triton.jit
def multiply_low_high(a, b):
return (
a * b,
((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32)
)
@triton.jit
def single_round(c0, c1, c2, c3, k0, k1):
A = PHILOX_ROUND_A()
B = PHILOX_ROUND_B()
lo0, hi0 = multiply_low_high(A, c0)
lo1, hi1 = multiply_low_high(B, c2)
return (
hi1 ^ c1 ^ k0,
lo1,
hi0 ^ c3 ^ k1,
lo0,
)
@triton.jit
def raise_key(k0, k1):
return (
k0 + PHILOX_KEY_A(),
k1 + PHILOX_KEY_B(),
)
@triton.jit
def philox_f(c0, c1, c2, c3, k0, k1):
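# ten Philox rounds; raise_key bumps the key between consecutive rounds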
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
return c0, c1, c2, c3
@triton.jit
def uint32_to_uniform_float(x):
"""
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
This was originally designed for uint32, but it works with int32 too as long as the int32 values
uniformly cover all the values the type can take.
"""
mantissa = x & 0x7fffff
exp = 127
res = mantissa | (exp << 23)
return res.to(tl.float32, bitcast=True) - 1.0
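# CPU sketch of the same bit trick (illustrative only, not used by the kernels): forcing
# the exponent to 127 yields a float in [1, 2) with a random mantissa; subtracting 1.0
# shifts it to [0, 1).
def _uint32_to_uniform_float_ref(x):
    import struct
    bits = (x & 0x7FFFFF) | (127 << 23)
    return struct.unpack('<f', struct.pack('<I', bits))[0] - 1.0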
@triton.jit
def pair_uniform_to_normal(u1, u2):
"""Box-Muller transform"""
u1 = tl.maximum(1.0e-7, u1)
th = 6.283185307179586 * u2
r = tl.sqrt(-2.0 * tl.log(u1))
return r * tl.cos(th), r * tl.sin(th)
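# Reference version of the Box-Muller transform above (a sketch mirroring the kernel):
# for u1, u2 ~ U(0, 1), r = sqrt(-2 ln u1) and th = 2*pi*u2 give two independent
# N(0, 1) samples r*cos(th) and r*sin(th).
def _pair_uniform_to_normal_ref(u1, u2):
    import math
    u1 = max(1.0e-7, u1)  # clamp away from 0 to keep the log finite, as in the kernel
    th = 2.0 * math.pi * u2
    r = math.sqrt(-2.0 * math.log(u1))
    return r * math.cos(th), r * math.sin(th)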
@triton.jit
def randint4x(seed, offset):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns four
blocks of random :code:`int32`.
This is the most efficient entry point
to Triton's Philox pseudo-random number generator.
:param seed: The seed for generating random numbers.
:param offset: The offsets to generate random numbers for.
"""
z = 0
return philox_f(offset, z, z, z, seed, z)
@triton.jit
def randint(seed, offset):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
block of random :code:`int32`.
If you need multiple streams of random numbers,
using `randint4x` is likely to be faster than calling `randint` 4 times.
:param seed: The seed for generating random numbers.
:param offset: The offsets to generate random numbers for.
"""
ret, _, _, _ = randint4x(seed, offset)
return ret
@triton.jit
def rand(seed, offset):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`U(0, 1)`.
:param seed: The seed for generating random numbers.
:param offset: The offsets to generate random numbers for.
"""
source = randint(seed, offset)
return uint32_to_uniform_float(source)
@triton.jit
def randn(seed, offset):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`.
:param seed: The seed for generating random numbers.
:param offset: The offsets to generate random numbers for.
"""
i1, i2, _, _ = randint4x(seed, offset)
u1 = uint32_to_uniform_float(i1)
u2 = uint32_to_uniform_float(i2)
n1, _ = pair_uniform_to_normal(u1, u2)
return n1
@triton.jit
def rand4x(seed, offsets):
"""
Given a :code:`seed` scalar and an :code:`offsets` block,
returns four blocks of random :code:`float32` in :math:`U(0, 1)`.
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
i1, i2, i3, i4 = randint4x(seed, offsets)
u1 = uint32_to_uniform_float(i1)
u2 = uint32_to_uniform_float(i2)
u3 = uint32_to_uniform_float(i3)
u4 = uint32_to_uniform_float(i4)
return u1, u2, u3, u4
@triton.jit
def randn4x(seed, offset):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns four blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`.
:param seed: The seed for generating random numbers.
:param offset: The offsets to generate random numbers for.
"""
u1, u2, u3, u4 = rand4x(seed, offset)
n1, n2 = pair_uniform_to_normal(u1, u2)
n3, n4 = pair_uniform_to_normal(u3, u4)
return n1, n2, n3, n4
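# Usage sketch (illustrative, not part of this module): drawing four independent random
# streams per offset with a single `randint4x` call, which the docstrings above recommend
# over four separate `randint` calls. `BLOCK` is a hypothetical meta-parameter.
@triton.jit
def _four_streams_kernel(X0, X1, X2, X3, N, seed, **meta):
    BLOCK = meta['BLOCK']
    offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    r0, r1, r2, r3 = randint4x(seed, offset)
    mask = offset < N
    tl.store(X0 + offset, r0, mask=mask)
    tl.store(X1 + offset, r1, mask=mask)
    tl.store(X2 + offset, r2, mask=mask)
    tl.store(X3 + offset, r3, mask=mask)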

View File

@@ -43,7 +43,7 @@ def add_kernel(
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output)
tl.store(output_ptr + offsets, output, mask=mask)
# %%

View File

@@ -0,0 +1,164 @@
"""
Low-Memory Dropout
====================
In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:
- The limitations of naive implementations of Dropout with PyTorch
- Parallel pseudo-random number generation in Triton
"""
# %%
# Baseline
# -------------
# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
# of deep neural networks in low-data regimes (i.e. regularization).
#
# It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
# output has a probability :math:`p` of being changed to zero, and otherwise it is copied from the input.
# This forces the network to perform well even when only a fraction :math:`1 - p` of the input scalars is available.
#
# At evaluation time we want to use the full power of the network so we set :math:`p=0`. Naively this would
# increase the norm of the output (which can be a bad thing, e.g. it can lead to an artificial decrease
# in the output softmax temperature). To prevent this we multiply the output by :math:`\frac{1}{1 - p}`, which
# keeps the norm consistent regardless of the dropout probability.
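# In expectation this is exact: :math:`\mathrm{E}[\mathrm{output}_i] = (1 - p) \cdot \frac{x_i}{1 - p} + p \cdot 0 = x_i`.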
#
# Let's first take a look at the baseline implementation.
import tabulate
import torch
import triton
import triton.language as tl
@triton.jit
def _dropout(
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
**meta,
):
BLOCK_SIZE = meta['BLOCK_SIZE']
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Load data
x = tl.load(x_ptr + offsets, mask=mask)
x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
# The line below is the crucial part, described in the paragraph above!
output = tl.where(x_keep, x / (1 - p), 0.0)
# Write-back output
tl.store(output_ptr + offsets, output, mask=mask)
def dropout(x, x_keep, p):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
["input"] + x.tolist(),
["keep mask"] + x_keep.tolist(),
["output"] + output.tolist()
]))
# %%
# Seeded dropout
# -------------
# The implementation of dropout above works fine, but it can be a bit awkward to deal with. First,
# we need to store the dropout mask for backpropagation. Second, dropout state management can get
# very tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in
# https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation
# that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
# of persisting randomness across multiple invocations of the kernel.
#
# Pseudorandom number generation in Triton is simple! In this tutorial we will use the
# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
# other :ref:`random number generation strategies <Random Number Generation>`.
#
# .. note::
# Triton's implementation of PRNG is based on the Philox algorithm (described in [SALMON2011]_).
#
# Let's put it all together.
@triton.jit
def _seeded_dropout(
x_ptr,
output_ptr,
n_elements,
p,
seed,
**meta,
):
# compute memory offsets of elements handled by this instance
BLOCK_SIZE = meta['BLOCK_SIZE']
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# load data from x
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
# randomly prune it
random = tl.rand(seed, offsets)
x_keep = random > p
# write-back
output = tl.where(x_keep, x / (1 - p), 0.0)
tl.store(output_ptr + offsets, output, mask=mask)
def seeded_dropout(x, p, seed):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
return output
x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)
print(tabulate.tabulate([
["input"] + x.tolist(),
["output (seed = 123)"] + output.tolist(),
["output (seed = 123)"] + output2.tolist(),
["output (seed = 512)"] + output3.tolist()
]))
# %%
# Et Voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
# If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you
# to explore the `triton/language/random` module!
# %%
# Exercises
# -------------
# 1. Extend the kernel to operate over a matrix and use a vector of seeds, one per row.
# 2. Add support for striding.
# 3. (challenge) Implement a kernel for the sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed.
# %%
# References
# --------------
#
# .. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011
# .. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014