[LANGUAGE] Fixed issue with duplicates in large arrays of random uniform numbers (#338)
This commit is contained in:
@@ -132,34 +132,6 @@ def test_randint(size, seed, device='cuda'):
|
|||||||
out_ref = [gen.random_raw()[0] for _ in out_tri]
|
out_ref = [gen.random_raw()[0] for _ in out_tri]
|
||||||
assert out_tri == out_ref
|
assert out_tri == out_ref
|
||||||
|
|
||||||
# test conversion of random uint32 into random float in [0, 1)
|
|
||||||
def test_uint32_to_uniform_float():
|
|
||||||
@triton.jit
|
|
||||||
def kernel(SRC, TGT, N, **meta):
|
|
||||||
pid = tl.program_id(0)
|
|
||||||
offset = pid * BLOCK + tl.arange(0, BLOCK)
|
|
||||||
src = tl.load(SRC + offset)
|
|
||||||
tgt = tl.random.uint32_to_uniform_float(src)
|
|
||||||
tl.store(TGT + offset, tgt, mask=offset < N)
|
|
||||||
|
|
||||||
def run(source):
|
|
||||||
target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
|
|
||||||
N = source.numel()
|
|
||||||
grid = lambda meta: (triton.cdiv(N, BLOCK),)
|
|
||||||
kernel[grid](source, target, N)
|
|
||||||
return target
|
|
||||||
|
|
||||||
# check range of edge values
|
|
||||||
n = 100
|
|
||||||
source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
|
|
||||||
target = run(source).tolist()
|
|
||||||
assert target == sorted(target)
|
|
||||||
assert all(0.0 <= num < 1.0 for num in target)
|
|
||||||
# check distribution is uniform
|
|
||||||
source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
|
|
||||||
target = run(source).tolist()
|
|
||||||
assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01
|
|
||||||
|
|
||||||
# test uniform PRNG
|
# test uniform PRNG
|
||||||
@pytest.mark.parametrize('size, seed',
|
@pytest.mark.parametrize('size, seed',
|
||||||
[(size, seed) for size in [1000000]\
|
[(size, seed) for size in [1000000]\
|
||||||
|
@@ -100,11 +100,9 @@ def uint32_to_uniform_float(x):
|
|||||||
This was originally designed for uint32, but it works with int32 too as long as the int32 uniformly
|
This was originally designed for uint32, but it works with int32 too as long as the int32 uniformly
|
||||||
covers all the possible values it can take.
|
covers all the possible values it can take.
|
||||||
"""
|
"""
|
||||||
mantissa = x & 0x7fffff
|
max = 2147483647.
|
||||||
exp = 127
|
x = tl.where(x < 0, -x - 1, x)
|
||||||
res = mantissa | (exp << 23)
|
return x / max
|
||||||
return res.to(tl.float32, bitcast=True) - 1.0
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def pair_uniform_to_normal(u1, u2):
|
def pair_uniform_to_normal(u1, u2):
|
||||||
|
Reference in New Issue
Block a user