diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index 042065403..39ae59e35 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -175,3 +175,24 @@ def test_randn(size, seed, device='cuda'): kernel[grid](x, N, seed) assert abs(x.mean()) < 1e-2 assert abs(x.std() - 1) < 1e-2 + + +# tl.rand() should never produce >=1.0 + +def test_rand_limits(): + @triton.jit + def kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = tl.random.uint32_to_uniform_float(x) + tl.store(output + idx, y) + + min_max_int32 = torch.tensor([ + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + ], dtype=torch.int32, device='cuda') + output = torch.empty(2, dtype=torch.float32, device='cuda') + kernel[(1,)](min_max_int32, output, 2) + + assert output[0] == output[1] + assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/python/triton/language/random.py b/python/triton/language/random.py index c95eac9fc..67de92c43 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -91,9 +91,10 @@ def uint32_to_uniform_float(x): Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ x = x.to(tl.int32, bitcast=True) - max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + scale = 4.6566127342e-10 x = tl.where(x < 0, -x - 1, x) - return x * max + return x * scale @triton.jit