[FRONTEND][RANDOM] Make sure offset dtype is always uint32 before calling uint32_to_uniform_float (#427)
This commit is contained in:
@@ -76,7 +76,7 @@ def uint32_to_uniform_float(x):
|
|||||||
"""
|
"""
|
||||||
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||||
"""
|
"""
|
||||||
two_to_the_minus_32 = 2.328306e-10
|
two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||||
return x * two_to_the_minus_32
|
return x * two_to_the_minus_32
|
||||||
|
|
||||||
|
|
||||||
@@ -89,6 +89,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
:param seed: The seed for generating random numbers.
|
:param seed: The seed for generating random numbers.
|
||||||
:param offsets: The offsets to generate random numbers for.
|
:param offsets: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
|
offset = offset.to(tl.uint32, bitcast=True)
|
||||||
source = randint(seed, offset, n_rounds)
|
source = randint(seed, offset, n_rounds)
|
||||||
return uint32_to_uniform_float(source)
|
return uint32_to_uniform_float(source)
|
||||||
|
|
||||||
@@ -102,6 +103,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
:param seed: The seed for generating random numbers.
|
:param seed: The seed for generating random numbers.
|
||||||
:param offsets: The offsets to generate random numbers for.
|
:param offsets: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
|
offsets = offsets.to(tl.uint32, bitcast=True)
|
||||||
i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
|
i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
|
||||||
u1 = uint32_to_uniform_float(i1)
|
u1 = uint32_to_uniform_float(i1)
|
||||||
u2 = uint32_to_uniform_float(i2)
|
u2 = uint32_to_uniform_float(i2)
|
||||||
|
Reference in New Issue
Block a user