From ccf9abe0ba081f13d37c9966a466f9984cd92747 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 21 Jan 2022 18:05:55 -0800 Subject: [PATCH] [FRONTEND][RANDOM] Improved backward compatibility of RNG (#438) The unsigned int PR definitely improved our RNG. However, it requires different floating point arithmetics which, means the results are not bit-wise identical to how they were before. This commit revives backward compatibility, but we should change it back to the "right" way later. --- python/triton/language/random.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index d5691bb72..c95eac9fc 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -77,13 +77,23 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): # rand # ------------------- +# @triton.jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + @triton.jit def uint32_to_uniform_float(x): """ Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ - two_to_the_minus_32: tl.constexpr = 2.328306e-10 - return x * two_to_the_minus_32 + x = x.to(tl.int32, bitcast=True) + max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. + x = tl.where(x < 0, -x - 1, x) + return x * max @triton.jit