[FRONTEND] use unsigned integers to simplify RNG (#417)
This commit is contained in:
committed by
GitHub
parent
001fb757fe
commit
120cda015e
@@ -2,11 +2,6 @@ import triton
|
|||||||
from . import core as tl
|
from . import core as tl
|
||||||
|
|
||||||
|
|
||||||
# Notes
|
|
||||||
# 1. triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's complement operations are equivalent to uint operations.
|
|
||||||
# 2. multiply_low_high is currently inefficient.
|
|
||||||
# 3. Even though philox sampling technically outputs ints, in many places we pretend they are actually uints, e.g. in uint32_to_uniform_float
|
|
||||||
|
|
||||||
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
|
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
|
||||||
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
|
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
|
||||||
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
|
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
|
||||||
@@ -17,9 +12,6 @@ N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
|
|||||||
# randint
|
# randint
|
||||||
# -------------------
|
# -------------------
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def hacky_to_uint64(x):
|
|
||||||
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
@@ -40,6 +32,7 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
k1 = k1 + PHILOX_KEY_B
|
k1 = k1 + PHILOX_KEY_B
|
||||||
return c0, c1, c2, c3
|
return c0, c1, c2, c3
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
"""
|
||||||
@@ -55,6 +48,7 @@ def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
ret, _, _, _ = randint4x(seed, offset, n_rounds)
|
ret, _, _, _ = randint4x(seed, offset, n_rounds)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
"""
|
||||||
@@ -67,11 +61,10 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
:param seed: The seed for generating random numbers.
|
:param seed: The seed for generating random numbers.
|
||||||
:param offsets: The offsets to generate random numbers for.
|
:param offsets: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting
|
z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting
|
||||||
seed = seed + 0
|
seed = seed.to(tl.uint64)
|
||||||
seed = hacky_to_uint64(seed) # uint will solve this
|
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
|
||||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
seed_lo = (seed & 0xffffffff).to(tl.uint32)
|
||||||
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
|
||||||
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
|
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
|
||||||
|
|
||||||
|
|
||||||
@@ -82,13 +75,11 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
@triton.jit
|
@triton.jit
|
||||||
def uint32_to_uniform_float(x):
|
def uint32_to_uniform_float(x):
|
||||||
"""
|
"""
|
||||||
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
|
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||||
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
|
|
||||||
covers all the possible values it can take.
|
|
||||||
"""
|
"""
|
||||||
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
|
two_to_the_minus_32 = 2.328306e-10
|
||||||
x = tl.where(x < 0, -x - 1, x)
|
return x * two_to_the_minus_32
|
||||||
return x * max
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
@@ -102,6 +93,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
source = randint(seed, offset, n_rounds)
|
source = randint(seed, offset, n_rounds)
|
||||||
return uint32_to_uniform_float(source)
|
return uint32_to_uniform_float(source)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
"""
|
||||||
@@ -122,6 +114,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
# randn
|
# randn
|
||||||
# -------------------
|
# -------------------
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def pair_uniform_to_normal(u1, u2):
|
def pair_uniform_to_normal(u1, u2):
|
||||||
"""Box-Muller transform"""
|
"""Box-Muller transform"""
|
||||||
@@ -130,6 +123,7 @@ def pair_uniform_to_normal(u1, u2):
|
|||||||
r = tl.sqrt(-2.0 * tl.log(u1))
|
r = tl.sqrt(-2.0 * tl.log(u1))
|
||||||
return r * tl.cos(th), r * tl.sin(th)
|
return r * tl.cos(th), r * tl.sin(th)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
"""
|
||||||
@@ -145,6 +139,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
n1, _ = pair_uniform_to_normal(u1, u2)
|
n1, _ = pair_uniform_to_normal(u1, u2)
|
||||||
return n1
|
return n1
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
"""
|
||||||
|
Reference in New Issue
Block a user