From 120cda015eaf541ab9f61237e21dffb9688d3b12 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Thu, 6 Jan 2022 10:49:09 -0800 Subject: [PATCH] [FRONTEND] use unsigned integers to simplify RNG (#417) --- python/triton/language/random.py | 61 +++++++++++++++----------------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index e1ac3c30a..cb2ddfc6b 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -2,24 +2,16 @@ import triton from . import core as tl -# Notes -# 1. triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's complement operations are equivalent to uint operations. -# 2. multiply_low_high is currently inefficient. -# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float - -PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 -PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 -PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 -PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57 -N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox +PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 +PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 +PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 +PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57 +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox # ------------------- # randint # ------------------- -@triton.jit -def hacky_to_uint64(x): - return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64) @triton.jit def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): @@ -40,12 +32,13 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): k1 = k1 + PHILOX_KEY_B return c0, c1, c2, c3 + @triton.jit def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Given a :code:`seed` scalar and an :code:`offset` block, returns a single - block of random :code:`int32`. - + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + If you need multiple streams of random numbers, using `randint4x` is likely to be faster than calling `randint` 4 times. @@ -55,23 +48,23 @@ def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): ret, _, _, _ = randint4x(seed, offset, n_rounds) return ret + @triton.jit def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Given a :code:`seed` scalar and an :code:`offset` block, returns four - blocks of random :code:`int32`. - - This is the maximally efficient entry point + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point to Triton's Philox pseudo-random number generator. :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting - seed = seed + 0 - seed = hacky_to_uint64(seed) # uint will solve this - seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32) - seed_lo = (seed & 0xffffffff).to(tl.int32) + z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting + seed = seed.to(tl.uint64) + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds) @@ -82,18 +75,16 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): @triton.jit def uint32_to_uniform_float(x): """ - Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). - This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly - covers all the possible values it can take. + Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ - max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. - x = tl.where(x < 0, -x - 1, x) - return x * max + two_to_the_minus_32 = 2.328306e-10 + return x * two_to_the_minus_32 + @triton.jit def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Given a :code:`seed` scalar and an :code:`offset` block, + Given a :code:`seed` scalar and an :code:`offset` block, returns a block of random :code:`float32` in :math:`U(0, 1)` :param seed: The seed for generating random numbers. @@ -102,6 +93,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): source = randint(seed, offset, n_rounds) return uint32_to_uniform_float(source) + @triton.jit def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ @@ -122,6 +114,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): # randn # ------------------- + @triton.jit def pair_uniform_to_normal(u1, u2): """Box-Muller transform""" @@ -130,10 +123,11 @@ def pair_uniform_to_normal(u1, u2): r = tl.sqrt(-2.0 * tl.log(u1)) return r * tl.cos(th), r * tl.sin(th) + @triton.jit def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Given a :code:`seed` scalar and an :code:`offset` block, + Given a :code:`seed` scalar and an :code:`offset` block, returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` :param seed: The seed for generating random numbers. @@ -145,6 +139,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): n1, _ = pair_uniform_to_normal(u1, u2) return n1 + @triton.jit def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """