[FRONTEND] use unsigned integers to simplify RNG (#417)

This commit is contained in:
Madeleine Thompson
2022-01-06 10:49:09 -08:00
committed by GitHub
parent 001fb757fe
commit 120cda015e

View File

@@ -2,24 +2,16 @@ import triton
from . import core as tl from . import core as tl
# Notes PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
# 1. triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's complement operations are equivalent to uint operations. PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
# 2. multiply_low_high is currently inefficient. PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
# ------------------- # -------------------
# randint # randint
# ------------------- # -------------------
@triton.jit
def hacky_to_uint64(x):
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
@triton.jit @triton.jit
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
@@ -40,12 +32,13 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
k1 = k1 + PHILOX_KEY_B k1 = k1 + PHILOX_KEY_B
return c0, c1, c2, c3 return c0, c1, c2, c3
@triton.jit @triton.jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
""" """
Given a :code:`seed` scalar and an :code:`offset` block, returns a single Given a :code:`seed` scalar and an :code:`offset` block, returns a single
block of random :code:`int32`. block of random :code:`int32`.
If you need multiple streams of random numbers, If you need multiple streams of random numbers,
using `randint4x` is likely to be faster than calling `randint` 4 times. using `randint4x` is likely to be faster than calling `randint` 4 times.
@@ -55,23 +48,23 @@ def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
ret, _, _, _ = randint4x(seed, offset, n_rounds) ret, _, _, _ = randint4x(seed, offset, n_rounds)
return ret return ret
@triton.jit @triton.jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
""" """
Given a :code:`seed` scalar and an :code:`offset` block, returns four Given a :code:`seed` scalar and an :code:`offset` block, returns four
blocks of random :code:`int32`. blocks of random :code:`int32`.
This is the maximally efficient entry point This is the maximally efficient entry point
to Triton's Philox pseudo-random number generator. to Triton's Philox pseudo-random number generator.
:param seed: The seed for generating random numbers. :param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for. :param offsets: The offsets to generate random numbers for.
""" """
z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting
seed = seed + 0 seed = seed.to(tl.uint64)
seed = hacky_to_uint64(seed) # uint will solve this seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32) seed_lo = (seed & 0xffffffff).to(tl.uint32)
seed_lo = (seed & 0xffffffff).to(tl.int32)
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds) return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
@@ -82,18 +75,16 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
@triton.jit @triton.jit
def uint32_to_uniform_float(x): def uint32_to_uniform_float(x):
""" """
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
covers all the possible values it can take.
""" """
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. two_to_the_minus_32 = 2.328306e-10
x = tl.where(x < 0, -x - 1, x) return x * two_to_the_minus_32
return x * max
@triton.jit @triton.jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
""" """
Given a :code:`seed` scalar and an :code:`offset` block, Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`U(0, 1)` returns a block of random :code:`float32` in :math:`U(0, 1)`
:param seed: The seed for generating random numbers. :param seed: The seed for generating random numbers.
@@ -102,6 +93,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
source = randint(seed, offset, n_rounds) source = randint(seed, offset, n_rounds)
return uint32_to_uniform_float(source) return uint32_to_uniform_float(source)
@triton.jit @triton.jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
""" """
@@ -122,6 +114,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
# randn # randn
# ------------------- # -------------------
@triton.jit @triton.jit
def pair_uniform_to_normal(u1, u2): def pair_uniform_to_normal(u1, u2):
"""Box-Muller transform""" """Box-Muller transform"""
@@ -130,10 +123,11 @@ def pair_uniform_to_normal(u1, u2):
r = tl.sqrt(-2.0 * tl.log(u1)) r = tl.sqrt(-2.0 * tl.log(u1))
return r * tl.cos(th), r * tl.sin(th) return r * tl.cos(th), r * tl.sin(th)
@triton.jit @triton.jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
""" """
Given a :code:`seed` scalar and an :code:`offset` block, Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
:param seed: The seed for generating random numbers. :param seed: The seed for generating random numbers.
@@ -145,6 +139,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
n1, _ = pair_uniform_to_normal(u1, u2) n1, _ = pair_uniform_to_normal(u1, u2)
return n1 return n1
@triton.jit @triton.jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
""" """