[FRONTEND] Make tl.rand() 1-exclusive (#601)
@@ -175,3 +175,24 @@ def test_randn(size, seed, device='cuda'):
     kernel[grid](x, N, seed)
     assert abs(x.mean()) < 1e-2
     assert abs(x.std() - 1) < 1e-2
+
+
+# tl.rand() should never produce >=1.0
+
+def test_rand_limits():
+    @triton.jit
+    def kernel(input, output, n: tl.constexpr):
+        idx = tl.arange(0, n)
+        x = tl.load(input + idx)
+        y = tl.random.uint32_to_uniform_float(x)
+        tl.store(output + idx, y)
+
+    min_max_int32 = torch.tensor([
+        torch.iinfo(torch.int32).min,
+        torch.iinfo(torch.int32).max,
+    ], dtype=torch.int32, device='cuda')
+    output = torch.empty(2, dtype=torch.float32, device='cuda')
+    kernel[(1,)](min_max_int32, output, 2)
+
+    assert output[0] == output[1]
+    assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0
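Note (illustration, not part of the diff): the test drives the two extreme int32 bit patterns, torch.iinfo(torch.int32).min and torch.iinfo(torch.int32).max, through the kernel and expects both to land on the same float just below 1.0. A minimal CPU-side sketch of why, using NumPy and a hypothetical helper uniform_from_int32 that mirrors the tl.where(x < 0, -x - 1, x) folding shown in the next hunk, assuming float32 arithmetic on the CPU matches what the kernel does on the GPU:

import numpy as np

# Sketch only: reproduce the int32 folding from uint32_to_uniform_float on the CPU.
# Negative values are folded with -x - 1, so INT32_MIN and INT32_MAX both map to
# 2147483647 and therefore to the same uniform float.
scale = np.float32(4.6566127342e-10)  # constant introduced by this commit

def uniform_from_int32(x: int) -> np.float32:
    folded = -x - 1 if x < 0 else x       # same fold as tl.where(x < 0, -x - 1, x)
    return np.float32(folded) * scale     # float32 multiply, as in the kernel

lo, hi = -2**31, 2**31 - 1                # INT32_MIN, INT32_MAX
assert uniform_from_int32(lo) == uniform_from_int32(hi) < np.float32(1.0)

The GPU test above checks the same property end to end through the Triton kernel.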
@@ -91,9 +91,10 @@ def uint32_to_uniform_float(x):
     Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
     """
     x = x.to(tl.int32, bitcast=True)
-    max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
+    # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
+    scale = 4.6566127342e-10
     x = tl.where(x < 0, -x - 1, x)
-    return x * max
+    return x * scale
 
 
 @triton.jit
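Why the constant changed (a worked check, not part of the diff): in float32, 2147483647 rounds up to 2**31, and the old constant 4.656613e-10 rounds to exactly 2**-31, so their product is exactly 1.0 and escapes the documented [0, 1) range. The new constant rounds to the float32 immediately below 2**-31, which keeps the product at the largest float32 strictly below 1.0. A quick NumPy sketch, assuming the kernel's float32 arithmetic behaves the same way on the GPU:

import numpy as np

max_int_f32 = np.float32(2147483647)      # rounds up to 2**31 = 2147483648.0
old_scale = np.float32(4.656613e-10)      # rounds to exactly 2**-31
new_scale = np.float32(4.6566127342e-10)  # one float32 step below 2**-31

print(max_int_f32 * old_scale)  # 1.0        -> violates the [0, 1) contract
print(max_int_f32 * new_scale)  # 0.99999994 -> largest float32 below 1.0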