[FRONTEND][RANDOM] Make sure offset dtype is always uint32 before calling uint32_to_uniform_float (#427)
This commit is contained in:
@@ -76,7 +76,7 @@ def uint32_to_uniform_float(x):
|
||||
"""
|
||||
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||
"""
|
||||
two_to_the_minus_32 = 2.328306e-10
|
||||
two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||
return x * two_to_the_minus_32
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
offset = offset.to(tl.uint32, bitcast=True)
|
||||
source = randint(seed, offset, n_rounds)
|
||||
return uint32_to_uniform_float(source)
|
||||
|
||||
@@ -102,6 +103,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
offsets = offsets.to(tl.uint32, bitcast=True)
|
||||
i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
|
||||
u1 = uint32_to_uniform_float(i1)
|
||||
u2 = uint32_to_uniform_float(i2)
|
||||
|
Reference in New Issue
Block a user