[LANGUAGE] Fixed issue with duplicates in large arrays of random uniform numbers (#338)

This commit is contained in:
Philippe Tillet
2021-10-10 15:22:34 -07:00
committed by GitHub
parent 9e9d781912
commit c3c0ff0552
2 changed files with 3 additions and 33 deletions

View File

@@ -132,34 +132,6 @@ def test_randint(size, seed, device='cuda'):
out_ref = [gen.random_raw()[0] for _ in out_tri]
assert out_tri == out_ref
# test conversion of random uint32 into random float in [0, 1]
def test_uint32_to_uniform_float():
@triton.jit
def kernel(SRC, TGT, N, **meta):
pid = tl.program_id(0)
offset = pid * BLOCK + tl.arange(0, BLOCK)
src = tl.load(SRC + offset)
tgt = tl.random.uint32_to_uniform_float(src)
tl.store(TGT + offset, tgt, mask=offset < N)
def run(source):
target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
N = source.numel()
grid = lambda meta: (triton.cdiv(N, BLOCK),)
kernel[grid](source, target, N)
return target
# check range of edge values
n = 100
source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
target = run(source).tolist()
assert target == sorted(target)
assert all(0.0 <= num < 1.0 for num in target)
# check distribution is uniform
source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
target = run(source).tolist()
assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\

View File

@@ -100,11 +100,9 @@ def uint32_to_uniform_float(x):
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
covers all the possible values it can take.
"""
mantissa = x & 0x7fffff
exp = 127
res = mantissa | (exp << 23)
return res.to(tl.float32, bitcast=True) - 1.0
max = 2147483647.
x = tl.where(x < 0, -x - 1, x)
return x / max
@triton.jit
def pair_uniform_to_normal(u1, u2):