From c3c0ff0552515af0eb995d67b95828b8cde50e33 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 10 Oct 2021 15:22:34 -0700 Subject: [PATCH] [LANGUAGE] Fixed issue with duplicates in large arrays of random uniform numbers (#338) --- python/test/unit/language/test_random.py | 28 ------------------------ python/triton/language/random.py | 8 +++---- 2 files changed, 3 insertions(+), 33 deletions(-) diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index a7f178f02..4c1261f1d 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -132,34 +132,6 @@ def test_randint(size, seed, device='cuda'): out_ref = [gen.random_raw()[0] for _ in out_tri] assert out_tri == out_ref -# test conversion of random uint32 into random float in [0, 1] -def test_uint32_to_uniform_float(): - @triton.jit - def kernel(SRC, TGT, N, **meta): - pid = tl.program_id(0) - offset = pid * BLOCK + tl.arange(0, BLOCK) - src = tl.load(SRC + offset) - tgt = tl.random.uint32_to_uniform_float(src) - tl.store(TGT + offset, tgt, mask=offset < N) - - def run(source): - target = -torch.ones(source.shape, dtype=torch.float32, device=source.device) - N = source.numel() - grid = lambda meta: (triton.cdiv(N, BLOCK),) - kernel[grid](source, target, N) - return target - - # check range of edge values - n = 100 - source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda() - target = run(source).tolist() - assert target == sorted(target) - assert all(0.0 <= num < 1.0 for num in target) - # check distribution is uniform - source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda() - target = run(source).tolist() - assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01 - # test uniform PRNG @pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000]\ diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 414f61cc0..dbc16e35d 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -100,11 +100,9 @@ def uint32_to_uniform_float(x): This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly covers all the possible values it can take. """ - mantissa = x & 0x7fffff - exp = 127 - res = mantissa | (exp << 23) - return res.to(tl.float32, bitcast=True) - 1.0 - + max = 2147483647. + x = tl.where(x < 0, -x - 1, x) + return x / max @triton.jit def pair_uniform_to_normal(u1, u2):