[FRONTEND] Make tl.rand() 1-exclusive (#601)

This commit is contained in:
Jason Ansel
2022-07-24 17:47:23 -07:00
committed by GitHub
parent e02e56dc63
commit 027321cdcf
2 changed files with 24 additions and 2 deletions

View File

@@ -175,3 +175,24 @@ def test_randn(size, seed, device='cuda'):
kernel[grid](x, N, seed)
assert abs(x.mean()) < 1e-2
assert abs(x.std() - 1) < 1e-2
# tl.rand() should never produce >=1.0
def test_rand_limits():
    """Check the uint32 -> uniform float conversion at the int32 extremes.

    Feeds the smallest and largest int32 values through
    ``tl.random.uint32_to_uniform_float`` and verifies that both produce the
    same float, and that this float lies within one float32 epsilon below 1.0
    without ever reaching 1.0.
    """

    @triton.jit
    def kernel(src_ptr, dst_ptr, n: tl.constexpr):
        # One program converts all `n` elements in a single pass.
        offsets = tl.arange(0, n)
        raw = tl.load(src_ptr + offsets)
        uniform = tl.random.uint32_to_uniform_float(raw)
        tl.store(dst_ptr + offsets, uniform)

    int32_info = torch.iinfo(torch.int32)
    extremes = torch.tensor(
        [int32_info.min, int32_info.max],
        dtype=torch.int32,
        device='cuda',
    )
    result = torch.empty(2, dtype=torch.float32, device='cuda')

    # Single-program launch; the element count is passed as a constexpr.
    kernel[(1,)](extremes, result, 2)

    # Both extremes are expected to yield the same value.
    assert result[0] == result[1]

    # That value must sit just below 1.0, never at or above it.
    lower_bound = 1.0 - torch.finfo(torch.float32).eps
    assert lower_bound <= result[0].item() < 1.0

View File

@@ -91,9 +91,10 @@ def uint32_to_uniform_float(x):
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
"""
x = x.to(tl.int32, bitcast=True)
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
# maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
scale = 4.6566127342e-10
x = tl.where(x < 0, -x - 1, x)
return x * max
return x * scale
@triton.jit