From bbc78f651600735bbcc98de3998b837dd0ce68c1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 11 Jan 2022 11:08:49 -0800 Subject: [PATCH] [FRONTEND][RANDOM] Make sure offset dtype is always uint32 before calling uint32_to_uniform_float (#427) --- python/triton/language/random.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 6f3645b41..69d7f4c4d 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -76,7 +76,7 @@ def uint32_to_uniform_float(x): """ Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ - two_to_the_minus_32 = 2.328306e-10 + two_to_the_minus_32: tl.constexpr = 2.328306e-10 return x * two_to_the_minus_32 @@ -89,6 +89,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ + offset = offset.to(tl.uint32, bitcast=True) source = randint(seed, offset, n_rounds) return uint32_to_uniform_float(source) @@ -102,6 +103,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ + offsets = offsets.to(tl.uint32, bitcast=True) i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) u1 = uint32_to_uniform_float(i1) u2 = uint32_to_uniform_float(i2)