[LANG] Added seeded random number generation - philox (#261)
198  python/test/language/test_random.py  Normal file
@@ -0,0 +1,198 @@
import torch
import triton
import triton.language as tl
import pytest
import scipy.stats
import numpy as np

from numpy.random import Philox

#####################################
## Reference Philox Implementation
#####################################

class PhiloxConfig:
    def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
        self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
        self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
        self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
        self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
        self.DTYPE = DTYPE

# The 32-bit variant: better suited to GPUs
PHILOX_32 = PhiloxConfig(
    PHILOX_KEY_A=0x9E3779B9,
    PHILOX_KEY_B=0xBB67AE85,
    PHILOX_ROUND_A=0xD2511F53,
    PHILOX_ROUND_B=0xCD9E8D57,
    DTYPE=np.uint32,
)

# The 64-bit variant: what numpy implements
PHILOX_64 = PhiloxConfig(
    PHILOX_KEY_A=0x9E3779B97F4A7C15,
    PHILOX_KEY_B=0xBB67AE8584CAA73B,
    PHILOX_ROUND_A=0xD2E7470EE14C6C93,
    PHILOX_ROUND_B=0xCA5A826395121157,
    DTYPE=np.uint64,
)
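
# (The key-schedule and round constants above appear to be the ones from
# Salmon et al., "Parallel Random Numbers: As Easy as 1, 2, 3", SC 2011;
# the key increments are fractional bits of the golden ratio and sqrt(3) - 1.)
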
class CustomPhilox4x:
    def __init__(self, seed, config):
        self._config = config
        seed = self._into_pieces(seed)
        self._key = np.array(seed[:2], dtype=self._dtype)
        self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)

    @property
    def _dtype(self):
        return self._config.DTYPE

    def _into_pieces(self, n, pad=4):
        # split an integer seed into `pad` words of the configured width
        res = []
        while len(res) < pad:
            res.append(np.array(n, dtype=self._dtype))
            n >>= (np.dtype(self._dtype).itemsize * 8)
        assert n == 0
        return tuple(res)

    def _multiply_low_high(self, a, b):
        # return the low and high halves of the double-width product a * b
        low = a * b
        high = int(a) * int(b)
        high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
        return low, high

    def _single_round(self, counter, key):
        lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
        lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
        ret0 = hi1 ^ counter[1] ^ key[0]
        ret1 = lo1
        ret2 = hi0 ^ counter[3] ^ key[1]
        ret3 = lo0
        return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)

    def _raise_key(self, key):
        ret0 = key[0] + self._config.PHILOX_KEY_A
        ret1 = key[1] + self._config.PHILOX_KEY_B
        return np.array([ret0, ret1], dtype=self._dtype)

    def random_raw(self):
        # run the standard 10 Philox rounds, then bump the counter
        counter = self._counter
        key = self._key
        for _ in range(10):
            counter = self._single_round(counter, key)
            key = self._raise_key(key)
        self.advance(1)
        return counter

    def advance(self, n_steps):
        self._counter[0] += n_steps
        assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"

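
# Illustrative sketch (not part of the original test suite): because Philox is
# counter-based, skipping ahead is O(1) -- advancing the counter by k yields
# the same block as generating and discarding k blocks.
def _demo_skip_ahead(seed=0, k=5, config=PHILOX_32):
    slow = CustomPhilox4x(seed, config=config)
    for _ in range(k):
        slow.random_raw()              # generate and discard k blocks
    fast = CustomPhilox4x(seed, config=config)
    fast.advance(k)                    # jump the counter directly
    assert (slow.random_raw() == fast.random_raw()).all()
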
class CustomPhilox(CustomPhilox4x):
    """Flattens the 4-word Philox blocks into a stream of single integers."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.buffer = []

    def random_raw(self):
        if len(self.buffer) == 0:
            self.buffer = list(super().random_raw())[::-1]
        return int(self.buffer.pop())

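
# Illustrative cross-check (an assumption, not an original test): with the
# PHILOX_64 configuration, the reference generator should reproduce numpy's
# Philox stream when the seed is passed directly as the key, which bypasses
# numpy's SeedSequence scrambling.
def _demo_matches_numpy(seed=42, n=16):
    gen = CustomPhilox(seed, config=PHILOX_64)
    ref = Philox(key=seed).random_raw(n)   # numpy's raw 64-bit outputs
    assert [gen.random_raw() for _ in range(n)] == [int(r) for r in ref]
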
#####################################
## Unit Tests
#####################################

BLOCK = 1024

# test generation of random uint32
@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in ['10', '4,53', '10000']
                          for seed in [0, 42, 124, 54]])
def test_randint(size, seed, device='cuda'):
    size = list(map(int, size.split(',')))

    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.randint(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)

    # triton result
    x = torch.empty(size, dtype=torch.int32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
    # reference result
    gen = CustomPhilox4x(seed, config=PHILOX_32)
    out_ref = [gen.random_raw()[0] for _ in out_tri]
    assert out_tri == out_ref
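
# In other words, tl.randint(seed, offset) apparently yields the first 32-bit
# word of the Philox block whose counter equals `offset`: the reference above
# consumes random_raw()[0] exactly once per element, advancing the counter
# each time.
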
# test conversion of random uint32 into random float in [0, 1)
def test_uint32_to_uniform_float():
    @triton.jit
    def kernel(SRC, TGT, N, **meta):
        pid = tl.program_id(0)
        offset = pid * BLOCK + tl.arange(0, BLOCK)
        src = tl.load(SRC + offset)
        tgt = tl.random.uint32_to_uniform_float(src)
        tl.store(TGT + offset, tgt, mask=offset < N)

    def run(source):
        target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
        N = source.numel()
        grid = lambda meta: (triton.cdiv(N, BLOCK),)
        kernel[grid](source, target, N)
        return target

    # check range of edge values
    n = 100
    source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
    target = run(source).tolist()
    assert target == sorted(target)
    assert all(0.0 <= num < 1.0 for num in target)
    # check distribution is uniform
    source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
    target = run(source).tolist()
    assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01
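
# A host-side reference for the conversion (an assumption about the kernel:
# the unsigned 32-bit value is scaled by roughly 2**-32, which is what makes
# the monotonicity and [0, 1) range checks above hold):
def _uint32_to_uniform_float_ref(x):
    x = np.asarray(x, dtype=np.int64) & 0xFFFFFFFF  # reinterpret as uint32
    return x * (1.0 / 2**32)                        # map onto [0, 1)
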
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]
                          for seed in [0, 42, 124, 54]])
def test_rand(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.rand(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)

    # triton result
    x = torch.empty(size, dtype=torch.float32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
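
# For n = 1,000,000 samples the 5% critical value of the one-sample KS
# statistic is about 1.36 / sqrt(n) ~= 0.00136, so the 0.01 bound used here
# is deliberately loose: it flags gross non-uniformity without being flaky.
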
# test normal PRNG
@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]
                          for seed in [0, 42, 124, 54]])
def test_randn(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.randn(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)

    # triton result
    x = torch.empty(size, dtype=torch.float32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert abs(x.mean()) < 1e-2
    assert abs(x.std() - 1) < 1e-2
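
# tl.randn presumably builds normals from uniform Philox outputs; a standard
# construction (an assumption, not necessarily Triton's exact kernel) is the
# Box-Muller transform, which turns two uniforms into two independent
# standard normals:
def _box_muller_ref(u1, u2):
    # u1 must lie in (0, 1] to keep the log finite
    r = np.sqrt(-2.0 * np.log(u1))
    theta = 2.0 * np.pi * u2
    return r * np.cos(theta), r * np.sin(theta)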