mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
* remove six * remove __future__ imports * remove six from setup.py, python 2.7 from README.rst
91 lines
3.0 KiB
Python
91 lines
3.0 KiB
Python
import hashlib
|
|
import numpy as np
|
|
import os
|
|
import random as _random
|
|
import struct
|
|
import sys
|
|
|
|
from gym import error
|
|
|
|
def np_random(seed=None):
    """Build a ``numpy.random.RandomState`` seeded from a hashed seed.

    Args:
        seed (Optional[int]): A non-negative integer, or None to let
            ``create_seed`` pick one.

    Returns:
        A ``(rng, seed)`` tuple: the seeded ``RandomState`` and the seed
        value returned by ``create_seed``.

    Raises:
        error.Error: If ``seed`` is given but is not a non-negative int.
    """
    if seed is not None:
        acceptable = isinstance(seed, int) and seed >= 0
        if not acceptable:
            raise error.Error('Seed must be a non-negative integer or omitted, not {}'.format(seed))

    seed = create_seed(seed)

    rng = np.random.RandomState()
    # The seed is hashed first, then split into a list of 32-bit words
    # before being handed to RandomState.seed.
    seed_words = _int_list_from_bigint(hash_seed(seed))
    rng.seed(seed_words)
    return rng, seed
|
|
|
|
def hash_seed(seed=None, max_bytes=8):
    """Hash a seed so that related seeds do not yield related PRNG streams.

    Several PRNGs are often live at the same time (typically because the
    environment runs in multiple processes), and there is literature
    suggesting that linearly correlated seeds can produce correlated
    outputs:

    http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
    http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
    http://dl.acm.org/citation.cfm?id=1276928

    Seeds are therefore hashed before use. (The scheme is likely not
    crypto-strength, but it should be enough to remove simple
    correlations.)

    Args:
        seed (Optional[int]): None seeds from an operating system specific randomness source.
        max_bytes: Maximum number of bytes to use in the hashed seed.

    Returns:
        The integer built from the first ``max_bytes`` bytes of the SHA-512
        digest of ``str(seed)``.
    """
    if seed is None:
        seed = create_seed(max_bytes=max_bytes)
    seed_text = str(seed)
    digest = hashlib.sha512(seed_text.encode('utf8')).digest()
    truncated = digest[:max_bytes]
    return _bigint_from_bytes(truncated)
|
|
|
|
def create_seed(a=None, max_bytes=8):
    """Produce an integer seed from an optional int or str input.

    The result depends on the type of ``a``:

    * None: ``max_bytes`` bytes are read from ``os.urandom`` and converted
      to an integer.
    * str: the UTF-8 encoding of the string, followed by its SHA-512
      digest, is cut to ``max_bytes`` bytes and converted to an integer.
    * int: the value is reduced modulo ``2 ** (8 * max_bytes)``.

    Args:
        a (Optional[int, str]): None seeds from an operating system specific randomness source.
        max_bytes: Maximum number of bytes to use in the seed.

    Returns:
        int: The resulting seed.

    Raises:
        error.Error: If ``a`` is not None, a str, or an int.
    """
    # Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
    if a is None:
        return _bigint_from_bytes(os.urandom(max_bytes))

    if isinstance(a, str):
        encoded = a.encode('utf8')
        encoded += hashlib.sha512(encoded).digest()
        return _bigint_from_bytes(encoded[:max_bytes])

    if isinstance(a, int):
        modulus = 2 ** (8 * max_bytes)
        return a % modulus

    raise error.Error('Invalid type for seed: {} ({})'.format(type(a), a))
|
|
|
|
def _bigint_from_bytes(bytes):
    """Convert a byte string into one non-negative arbitrary-size integer.

    The input is zero-padded at the end to a whole number of unsigned-int
    words, unpacked with ``struct`` (format ``"I"``, i.e. the platform's
    native byte order), and the words are combined with the first word as
    the least-significant one.

    Args:
        bytes: A bytes-like object (``bytes`` or ``bytearray``). It is never
            modified. An empty input yields 0.

    Returns:
        int: The non-negative integer encoded by the input.
    """
    # Ask struct for the width of "I" rather than hard-coding 4, so the
    # padding and the unpack format can never disagree.
    sizeof_int = struct.calcsize("I")
    # Pad only up to the next word boundary; already-aligned input gets no
    # padding instead of a whole redundant zero word.
    padding = -len(bytes) % sizeof_int
    # Build a new object instead of using ``+=``, which would extend a
    # caller-supplied bytearray in place.
    padded = bytes + b'\0' * padding
    int_count = len(padded) // sizeof_int
    unpacked = struct.unpack("{}I".format(int_count), padded)
    accum = 0
    for i, val in enumerate(unpacked):
        accum += 2 ** (sizeof_int * 8 * i) * val
    return accum
|
|
|
|
def _int_list_from_bigint(bigint):
    """Split a non-negative integer into a list of base-2**32 digits.

    Args:
        bigint: The non-negative integer to split.

    Returns:
        list: The digits, least-significant first. Zero is returned as
        ``[0]`` rather than an empty list.

    Raises:
        error.Error: If ``bigint`` is negative.
    """
    if bigint < 0:
        raise error.Error('Seed must be non-negative, not {}'.format(bigint))
    # Special case 0: the loop below would otherwise produce an empty list.
    if bigint == 0:
        return [0]

    base = 2 ** 32
    words = []
    remaining = bigint
    while remaining > 0:
        remaining, low_word = divmod(remaining, base)
        words.append(low_word)
    return words
|