[FRONTEND] use unsigned integers to simplify RNG (#417)

This commit is contained in:
Madeleine Thompson
2022-01-06 10:49:09 -08:00
committed by GitHub
parent 001fb757fe
commit 120cda015e

View File

@@ -2,11 +2,6 @@ import triton
from . import core as tl
# Notes
# 1. triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's complement operations are equivalent to uint operations.
# 2. multiply_low_high is currently inefficient.
# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
@@ -17,9 +12,6 @@ N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
# randint
# -------------------
@triton.jit
def hacky_to_uint64(x):
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
@triton.jit
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
@@ -40,6 +32,7 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
k1 = k1 + PHILOX_KEY_B
return c0, c1, c2, c3
@triton.jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
@@ -55,6 +48,7 @@ def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
ret, _, _, _ = randint4x(seed, offset, n_rounds)
return ret
@triton.jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
@@ -67,11 +61,10 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting
seed = seed + 0
seed = hacky_to_uint64(seed) # uint will solve this
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
seed_lo = (seed & 0xffffffff).to(tl.int32)
z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting
seed = seed.to(tl.uint64)
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
seed_lo = (seed & 0xffffffff).to(tl.uint32)
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
@@ -82,13 +75,11 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
@triton.jit
def uint32_to_uniform_float(x):
"""
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
covers all the possible values it can take.
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
"""
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
x = tl.where(x < 0, -x - 1, x)
return x * max
two_to_the_minus_32 = 2.328306e-10
return x * two_to_the_minus_32
@triton.jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
@@ -102,6 +93,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
source = randint(seed, offset, n_rounds)
return uint32_to_uniform_float(source)
@triton.jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
@@ -122,6 +114,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
# randn
# -------------------
@triton.jit
def pair_uniform_to_normal(u1, u2):
"""Box-Muller transform"""
@@ -130,6 +123,7 @@ def pair_uniform_to_normal(u1, u2):
r = tl.sqrt(-2.0 * tl.log(u1))
return r * tl.cos(th), r * tl.sin(th)
@triton.jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
@@ -145,6 +139,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
n1, _ = pair_uniform_to_normal(u1, u2)
return n1
@triton.jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""