[FRONTEND][RANDOM] Improved backward compatibility of RNG (#438)
The unsigned int PR definitely improved our RNG. However, it requires different floating point arithmetics which, means the results are not bit-wise identical to how they were before. This commit revives backward compatibility, but we should change it back to the "right" way later.
This commit is contained in:
@@ -77,13 +77,23 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
# rand
|
# rand
|
||||||
# -------------------
|
# -------------------
|
||||||
|
|
||||||
|
# @triton.jit
|
||||||
|
# def uint32_to_uniform_float(x):
|
||||||
|
# """
|
||||||
|
# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||||
|
# """
|
||||||
|
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||||
|
# return x * two_to_the_minus_32
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def uint32_to_uniform_float(x):
|
def uint32_to_uniform_float(x):
|
||||||
"""
|
"""
|
||||||
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||||
"""
|
"""
|
||||||
two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
x = x.to(tl.int32, bitcast=True)
|
||||||
return x * two_to_the_minus_32
|
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
|
||||||
|
x = tl.where(x < 0, -x - 1, x)
|
||||||
|
return x * max
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
Reference in New Issue
Block a user