[LANGUAGE] Fixed issue with duplicates in large arrays of random uniform numbers (#338)

This commit is contained in:
Philippe Tillet
2021-10-10 15:22:34 -07:00
committed by GitHub
parent 9e9d781912
commit c3c0ff0552
2 changed files with 3 additions and 33 deletions

View File

@@ -132,34 +132,6 @@ def test_randint(size, seed, device='cuda'):
out_ref = [gen.random_raw()[0] for _ in out_tri] out_ref = [gen.random_raw()[0] for _ in out_tri]
assert out_tri == out_ref assert out_tri == out_ref
# test conversion of random uint32 into random float in [0, 1)
def test_uint32_to_uniform_float():
    """Check ``tl.random.uint32_to_uniform_float`` on edge values and random input.

    Two properties are asserted on the floats written by the kernel:

    * for the int32 inputs ``0..n-1`` followed by ``-n..-1`` the outputs are
      already in sorted order and every one satisfies ``0.0 <= value < 1.0``;
    * for 100000 random int32 inputs the Kolmogorov-Smirnov statistic against
      ``uniform(0, 1)`` is below 0.01.
    """

    @triton.jit
    def kernel(SRC, TGT, N, **meta):
        # NOTE(review): BLOCK is not defined inside this test; presumably it is
        # a module-level constant of the test file -- confirm.
        pid = tl.program_id(0)
        offset = pid * BLOCK + tl.arange(0, BLOCK)
        # Mask the load with the same predicate as the store: when N is not a
        # multiple of BLOCK the trailing lanes would otherwise read past the
        # end of SRC.
        src = tl.load(SRC + offset, mask=offset < N)
        tgt = tl.random.uint32_to_uniform_float(src)
        tl.store(TGT + offset, tgt, mask=offset < N)

    def run(source):
        """Launch the kernel over ``source`` and return the float32 result."""
        # Pre-filled with -1.0, a value outside the range asserted below.
        target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
        N = source.numel()
        grid = lambda meta: (triton.cdiv(N, BLOCK),)
        kernel[grid](source, target, N)
        return target

    # check range of edge values
    n = 100
    source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
    target = run(source).tolist()
    assert target == sorted(target)
    assert all(0.0 <= num < 1.0 for num in target)
    # check distribution is uniform
    source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
    target = run(source).tolist()
    assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01
# test uniform PRNG # test uniform PRNG
@pytest.mark.parametrize('size, seed', @pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\ [(size, seed) for size in [1000000]\

View File

@@ -100,11 +100,9 @@ def uint32_to_uniform_float(x):
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
covers all the possible values it can take. covers all the possible values it can take.
""" """
mantissa = x & 0x7fffff max = 2147483647.
exp = 127 x = tl.where(x < 0, -x - 1, x)
res = mantissa | (exp << 23) return x / max
return res.to(tl.float32, bitcast=True) - 1.0
@triton.jit @triton.jit
def pair_uniform_to_normal(u1, u2): def pair_uniform_to_normal(u1, u2):