[FRONTEND] Make tl.rand() 1-exclusive (#601)
This commit is contained in:
@@ -175,3 +175,24 @@ def test_randn(size, seed, device='cuda'):
|
||||
kernel[grid](x, N, seed)
|
||||
assert abs(x.mean()) < 1e-2
|
||||
assert abs(x.std() - 1) < 1e-2
|
||||
|
||||
|
||||
# tl.rand() should never produce >=1.0
|
||||
|
||||
def test_rand_limits():
    """Check ``tl.random.uint32_to_uniform_float`` at the int32 extremes.

    The smallest and largest int32 values are converted on a CUDA device.
    Both must map to the same float, and that float must lie in
    ``[1 - eps, 1)``, where ``eps`` is the float32 machine epsilon.
    """
    @triton.jit
    def kernel(src_ptr, dst_ptr, n: tl.constexpr):
        # Launched below with a single-program grid covering all `n` elements.
        offsets = tl.arange(0, n)
        raw = tl.load(src_ptr + offsets)
        tl.store(dst_ptr + offsets, tl.random.uint32_to_uniform_float(raw))

    int32_info = torch.iinfo(torch.int32)
    extremes = torch.tensor(
        [int32_info.min, int32_info.max],
        dtype=torch.int32,
        device='cuda',
    )
    result = torch.empty(2, dtype=torch.float32, device='cuda')
    kernel[(1,)](extremes, result, 2)

    # Pull both results back to the host as plain Python floats.
    from_min, from_max = result.tolist()
    lower_bound = 1.0 - torch.finfo(torch.float32).eps

    assert from_min == from_max
    assert lower_bound <= from_min < 1.0
|
||||
|
@@ -91,9 +91,10 @@ def uint32_to_uniform_float(x):
|
||||
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||
"""
|
||||
x = x.to(tl.int32, bitcast=True)
|
||||
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
|
||||
# maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
|
||||
scale = 4.6566127342e-10
|
||||
x = tl.where(x < 0, -x - 1, x)
|
||||
return x * max
|
||||
return x * scale
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
Reference in New Issue
Block a user