diff --git a/python/triton/language/random.py b/python/triton/language/random.py index d5691bb72..c95eac9fc 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -77,13 +77,23 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): # rand # ------------------- +# @triton.jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + @triton.jit def uint32_to_uniform_float(x): """ Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ - two_to_the_minus_32: tl.constexpr = 2.328306e-10 - return x * two_to_the_minus_32 + x = x.to(tl.int32, bitcast=True) + max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. + x = tl.where(x < 0, -x - 1, x) + return x * max @triton.jit