[LANG] Added seeded random number generation - philox (#261)
@@ -66,7 +66,7 @@ def setup(app):
 import sys
 import os
 sys.path.insert(0, os.path.abspath('../python/'))
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
+extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
 autosummary_generate = True
 
 # Sphinx gallery
@@ -78,6 +78,9 @@ sphinx_gallery_conf = {
     'filename_pattern': '',
     'ignore_pattern': r'__init__\.py',
     'within_subsection_order': FileNameSortKey,
+    'reference_url': {
+        'sphinx_gallery': None,
+    }
 }
 
 # Add any paths that contain templates here, relative to this directory.

docs/getting-started/tutorials/random_bits.png  (new binary file, 41 KiB; not shown)
@@ -121,6 +121,19 @@ Comparison ops
     minimum
     maximum
 
+.. _Random Number Generation:
+
+Random Number Generation
+-------------------------
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    randint4x
+    randint
+    rand
+    randn
 
 Compiler Hint Ops
 -------------------

@@ -126,7 +126,7 @@ setup(
     author_email="phil@openai.com",
     description="A language and compiler for custom Deep Learning operations",
     long_description="",
-    packages=["triton", "triton/_C", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
+    packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
     install_requires=["torch"],
     package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
     include_package_data=True,

python/test/language/test_random.py  (new file, 198 lines)
@@ -0,0 +1,198 @@
import torch
import triton
import triton.language as tl
import pytest
import scipy.stats
import numpy as np

from numpy.random import Philox


#####################################
## Reference Philox Implementation
#####################################

class PhiloxConfig:
    def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
        self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
        self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
        self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
        self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
        self.DTYPE = DTYPE


# This is better for GPU
PHILOX_32 = PhiloxConfig(
    PHILOX_KEY_A=0x9E3779B9,
    PHILOX_KEY_B=0xBB67AE85,
    PHILOX_ROUND_A=0xD2511F53,
    PHILOX_ROUND_B=0xCD9E8D57,
    DTYPE=np.uint32,
)

# This is what numpy implements
PHILOX_64 = PhiloxConfig(
    PHILOX_KEY_A=0x9E3779B97F4A7C15,
    PHILOX_KEY_B=0xBB67AE8584CAA73B,
    PHILOX_ROUND_A=0xD2E7470EE14C6C93,
    PHILOX_ROUND_B=0xCA5A826395121157,
    DTYPE=np.uint64,
)


class CustomPhilox4x:
    def __init__(self, seed, config):
        self._config = config
        seed = self._into_pieces(seed)
        self._key = np.array(seed[:2], dtype=self._dtype)
        self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)

    @property
    def _dtype(self):
        return self._config.DTYPE

    def _into_pieces(self, n, pad=4):
        res = []
        while len(res) < pad:
            res.append(np.array(n, dtype=self._dtype))
            n >>= (np.dtype(self._dtype).itemsize * 8)
        assert n == 0
        return tuple(res)

    def _multiply_low_high(self, a, b):
        low = a * b
        high = int(a) * int(b)
        high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
        return low, high

    def _single_round(self, counter, key):
        lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
        lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
        ret0 = hi1 ^ counter[1] ^ key[0]
        ret1 = lo1
        ret2 = hi0 ^ counter[3] ^ key[1]
        ret3 = lo0
        return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)

    def _raise_key(self, key):
        ret0 = key[0] + self._config.PHILOX_KEY_A
        ret1 = key[1] + self._config.PHILOX_KEY_B
        return np.array([ret0, ret1], dtype=self._dtype)

    def random_raw(self):
        counter = self._counter
        key = self._key
        for _ in range(10):
            counter = self._single_round(counter, key)
            key = self._raise_key(key)
        self.advance(1)
        return counter

    def advance(self, n_steps):
        self._counter[0] += n_steps
        assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"


class CustomPhilox(CustomPhilox4x):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.buffer = []

    def random_raw(self):
        if len(self.buffer) == 0:
            self.buffer = list(super().random_raw())[::-1]
        return int(self.buffer.pop())


#####################################
## Unit Tests
#####################################

BLOCK = 1024


# test generation of random uint32
@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in ['10', '4,53', '10000']
                          for seed in [0, 42, 124, 54]]
                         )
def test_randint(size, seed, device='cuda'):
    size = list(map(int, size.split(',')))

    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.randint(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)
    # triton result
    x = torch.empty(size, dtype=torch.int32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
    # reference result
    gen = CustomPhilox4x(seed, config=PHILOX_32)
    out_ref = [gen.random_raw()[0] for _ in out_tri]
    assert out_tri == out_ref


# test conversion of random uint32 into random float in [0, 1]
def test_uint32_to_uniform_float():
    @triton.jit
    def kernel(SRC, TGT, N, **meta):
        pid = tl.program_id(0)
        offset = pid * BLOCK + tl.arange(0, BLOCK)
        src = tl.load(SRC + offset)
        tgt = tl.random.uint32_to_uniform_float(src)
        tl.store(TGT + offset, tgt, mask=offset < N)

    def run(source):
        target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
        N = source.numel()
        grid = lambda meta: (triton.cdiv(N, BLOCK),)
        kernel[grid](source, target, N)
        return target

    # check range of edge values
    n = 100
    source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
    target = run(source).tolist()
    assert target == sorted(target)
    assert all(0.0 <= num < 1.0 for num in target)
    # check distribution is uniform
    source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
    target = run(source).tolist()
    assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01


# test uniform PRNG
@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]
                          for seed in [0, 42, 124, 54]]
                         )
def test_rand(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.rand(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)
    # triton result
    x = torch.empty(size, dtype=torch.float32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01


# test normal PRNG
@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]
                          for seed in [0, 42, 124, 54]]
                         )
def test_randn(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.randn(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)
    # triton result
    x = torch.empty(size, dtype=torch.float32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert abs(x.mean()) < 1e-2
    assert abs(x.std() - 1) < 1e-2
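
Aside (a sketch, not part of the commit): the `from numpy.random import Philox` import above suggests the 64-bit reference class can be cross-checked against numpy's own Philox4x64-10 bit generator. Assuming numpy's `random_raw` yields the four output lanes of each counter block in order, such a check could look like:

    gen_ref = CustomPhilox(42, config=PHILOX_64)
    gen_np = Philox(key=42)
    assert [gen_ref.random_raw() for _ in range(8)] == list(map(int, gen_np.random_raw(8)))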
@@ -1,2 +1,4 @@
 from . import core
+from . import random
 from .core import *
+from .random import *
@@ -648,6 +648,7 @@ def cdiv(x, div):
     """
     return (x + div - 1) // div
 
+
 @triton.jit
 def minimum(x, y):
     """
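
(For orientation, not from the diff: `cdiv` is the ceiling division used to size launch grids throughout this commit, e.g. cdiv(10, 3) == (10 + 3 - 1) // 3 == 4.)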

python/triton/language/random.py  (new file, 208 lines)
@@ -0,0 +1,208 @@
import triton
import triton.language as tl


# Notes
# 1. triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's complement operations are equivalent to uint operations.
# 2. multiply_low_high is currently inefficient.
# 3. Even though technically philox sampling outputs int, in many places we pretend they are actually uints, e.g. in uint32_to_uniform_float.


@triton.jit
def PHILOX_KEY_A():
    # 0x9E3779B9
    return -1640531527


@triton.jit
def PHILOX_KEY_B():
    # 0xBB67AE85
    return -1150833019


@triton.jit
def PHILOX_ROUND_A():
    # 0xD2511F53
    return -766435501


@triton.jit
def PHILOX_ROUND_B():
    # 0xCD9E8D57
    return -845247145


@triton.jit
def hacky_to_uint64(x):
    return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)


@triton.jit
def multiply_low_high(a, b):
    return (
        a * b,
        ((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32)
    )


@triton.jit
def single_round(c0, c1, c2, c3, k0, k1):
    A = PHILOX_ROUND_A()
    B = PHILOX_ROUND_B()
    lo0, hi0 = multiply_low_high(A, c0)
    lo1, hi1 = multiply_low_high(B, c2)

    return (
        hi1 ^ c1 ^ k0,
        lo1,
        hi0 ^ c3 ^ k1,
        lo0,
    )


@triton.jit
def raise_key(k0, k1):
    return (
        k0 + PHILOX_KEY_A(),
        k1 + PHILOX_KEY_B(),
    )


@triton.jit
def philox_f(c0, c1, c2, c3, k0, k1):
    # ten Philox rounds, raising the key between rounds
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    k0, k1 = raise_key(k0, k1)
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    k0, k1 = raise_key(k0, k1)
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    k0, k1 = raise_key(k0, k1)
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    k0, k1 = raise_key(k0, k1)
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    k0, k1 = raise_key(k0, k1)
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    k0, k1 = raise_key(k0, k1)
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    k0, k1 = raise_key(k0, k1)
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    k0, k1 = raise_key(k0, k1)
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    k0, k1 = raise_key(k0, k1)
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
    return c0, c1, c2, c3


@triton.jit
def uint32_to_uniform_float(x):
    """
    Numerically stable function to convert a random integer into a random float
    uniformly sampled in [0, 1). This was originally designed for uint32, but it
    works with int32 too as long as the int32 uniformly covers all the possible
    values it can take.
    """
    mantissa = x & 0x7fffff
    exp = 127
    res = mantissa | (exp << 23)
    return res.to(tl.float32, bitcast=True) - 1.0
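# (Clarifying aside, not in the original file: with the exponent field fixed at
# 127, the bit pattern above decodes as 1.mantissa, i.e. a float32 uniformly
# spread over [1, 2); subtracting 1.0 maps it to [0, 1) using only the 23 random
# mantissa bits and no integer-to-float division.)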

@triton.jit
def pair_uniform_to_normal(u1, u2):
    """Box-Muller transform"""
    u1 = tl.maximum(1.0e-7, u1)
    th = 6.283185307179586 * u2
    r = tl.sqrt(-2.0 * tl.log(u1))
    return r * tl.cos(th), r * tl.sin(th)


@triton.jit
def randint4x(seed, offset):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns four
    blocks of random :code:`int32`.

    This is the maximally efficient entry point
    to Triton's Philox pseudo-random number generator.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    z = 0
    return philox_f(offset, z, z, z, seed, z)


@triton.jit
def randint(seed, offset):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns a single
    block of random :code:`int32`.

    If you need multiple streams of random numbers,
    using `randint4x` is likely to be faster than calling `randint` 4 times.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    ret, _, _, _ = randint4x(seed, offset)
    return ret


@triton.jit
def rand(seed, offset):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    source = randint(seed, offset)
    return uint32_to_uniform_float(source)


@triton.jit
def randn(seed, offset):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    i1, i2, _, _ = randint4x(seed, offset)
    u1 = uint32_to_uniform_float(i1)
    u2 = uint32_to_uniform_float(i2)
    n1, _ = pair_uniform_to_normal(u1, u2)
    return n1


@triton.jit
def rand4x(seed, offsets):
    """
    Given a :code:`seed` scalar and an :code:`offsets` block,
    returns four blocks of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    """
    i1, i2, i3, i4 = randint4x(seed, offsets)
    u1 = uint32_to_uniform_float(i1)
    u2 = uint32_to_uniform_float(i2)
    u3 = uint32_to_uniform_float(i3)
    u4 = uint32_to_uniform_float(i4)
    return u1, u2, u3, u4


@triton.jit
def randn4x(seed, offset):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns four blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    u1, u2, u3, u4 = rand4x(seed, offset)
    n1, n2 = pair_uniform_to_normal(u1, u2)
    n3, n4 = pair_uniform_to_normal(u3, u4)
    return n1, n2, n3, n4
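
For orientation (not part of the diff), a minimal host-side sketch of how these primitives are driven, mirroring the tests above; the kernel name and sizes are illustrative:

    import torch
    import triton
    import triton.language as tl

    BLOCK = 1024

    @triton.jit
    def _fill_uniform(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        tl.store(X + offset, tl.rand(seed, offset), mask=offset < N)

    x = torch.empty(10000, dtype=torch.float32, device='cuda')
    _fill_uniform[(triton.cdiv(x.numel(), BLOCK),)](x, x.numel(), 42)
    # re-launching with the same seed and offsets reproduces the same values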
@@ -43,7 +43,7 @@ def add_kernel(
     y = tl.load(y_ptr + offsets, mask=mask)
     output = x + y
     # Write x + y back to DRAM
-    tl.store(output_ptr + offsets, output)
+    tl.store(output_ptr + offsets, output, mask=mask)
 
 
 # %%

python/tutorials/04-low-memory-dropout.py  (new file, 164 lines)
@@ -0,0 +1,164 @@
"""
Low-Memory Dropout
==================

In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:

- The limitations of naive implementations of Dropout with PyTorch
- Parallel pseudo-random number generation in Triton
"""

# %%
# Baseline
# -------------
# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
# of deep neural networks in low-data regimes (i.e. regularization).
#
# It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
# output has a probability :math:`p` of being changed to zero, and otherwise it is copied from the input.
# This forces the network to perform well even when only :math:`1 - p` scalars from the input are available.
#
# At evaluation time we want to use the full power of the network, so we set :math:`p=0`. Naively this would
# increase the norm of the output (which can be a bad thing, e.g. it can lead to an artificial decrease
# in the output softmax temperature). To prevent this we multiply the output by :math:`\frac{1}{1 - p}`, which
# keeps the norm consistent regardless of the dropout probability.
#
# Let's first take a look at the baseline implementation.

import tabulate
import torch
import triton
import triton.language as tl


@triton.jit
def _dropout(
    x_ptr,  # pointer to the input
    x_keep_ptr,  # pointer to a mask of 0s and 1s
    output_ptr,  # pointer to the output
    n_elements,  # number of elements in the `x` tensor
    p,  # probability that an element of `x` is changed to zero
    **meta,
):
    BLOCK_SIZE = meta['BLOCK_SIZE']
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # The line below is the crucial part, described in the paragraph above!
    output = tl.where(x_keep, x / (1 - p), 0.0)
    # Write-back output
    tl.store(output_ptr + offsets, output, mask=mask)


def dropout(x, x_keep, p):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist()
]))

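# (Clarifying aside, not in the original tutorial: the x / (1 - p) rescaling keeps
# the output unbiased, since E[output_i] = (1 - p) * x_i / (1 - p) + p * 0 = x_i.)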

# %%
# Seeded dropout
# --------------
# The implementation of dropout above works fine, but it can be a bit awkward to deal with. Firstly,
# we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
# very tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in
# https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation
# that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
# of persisting randomness across multiple invocations of the kernel.
#
# Pseudo-random number generation in Triton is simple! In this tutorial we will use the
# :code:`triton.language.rand` function, which generates a block of uniformly distributed :code:`float32`
# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
# other :ref:`random number generation strategies <Random Number Generation>`.
#
# .. note::
#    Triton's implementation of PRNG is based on the Philox algorithm (described in [SALMON2011]_).
#
# Let's put it all together.


@triton.jit
def _seeded_dropout(
    x_ptr,
    output_ptr,
    n_elements,
    p,
    seed,
    **meta,
):
    # compute memory offsets of elements handled by this instance
    BLOCK_SIZE = meta['BLOCK_SIZE']
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["output (seed = 123)"] + output.tolist(),
    ["output (seed = 123)"] + output2.tolist(),
    ["output (seed = 512)"] + output3.tolist()
]))

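# (Sketch, not in the original tutorial: the seed-determinism claim below can be
# checked directly, e.g.
#     assert torch.equal(output, output2)      # same seed, identical mask
#     assert not torch.equal(output, output3)  # different seed, different mask
# where the second assert holds almost surely rather than with certainty.)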

# %%
# Et Voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
# If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you
# to explore the `triton/language/random` folder!

# %%
# Exercises
# -------------
# 1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.
# 2. Add support for striding.
# 3. (challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time, using a seed.

# %%
# References
# --------------
#
# .. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011
# .. [SRIVASTAVA2014] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014