2016-05-29 09:07:09 -07:00
|
|
|
import hashlib
|
2022-01-24 23:22:11 +01:00
|
|
|
from typing import Optional, List, Tuple, Union, Any
|
2016-05-29 09:07:09 -07:00
|
|
|
import os
|
|
|
|
import struct
|
2021-12-08 22:14:15 +01:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
from numpy.random import Generator
|
2016-05-29 09:07:09 -07:00
|
|
|
|
|
|
|
from gym import error
|
2021-12-08 22:14:15 +01:00
|
|
|
from gym.logger import deprecation
|
2016-05-29 09:07:09 -07:00
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
    """Build a ``RandomNumberGenerator`` driven by a ``PCG64`` bit generator.

    Args:
        seed: A non-negative ``int`` to seed from, or ``None`` to let
            ``np.random.SeedSequence`` draw fresh entropy.

    Returns:
        A ``(rng, seed)`` tuple, where the second element is the
        ``SeedSequence.entropy`` that was actually used.

    Raises:
        error.Error: if ``seed`` is not ``None`` and not a non-negative ``int``.
    """
    seed_ok = seed is None or (isinstance(seed, int) and seed >= 0)
    if not seed_ok:
        raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")

    sequence = np.random.SeedSequence(seed)
    used_entropy = sequence.entropy
    generator = RandomNumberGenerator(np.random.PCG64(sequence))
    return generator, used_entropy
|
2016-05-29 09:07:09 -07:00
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2021-12-08 22:14:15 +01:00
|
|
|
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
|
|
|
|
# RandomNumberGenerator = np.random.Generator
|
|
|
|
class RandomNumberGenerator(np.random.Generator):
    """``np.random.Generator`` subclass that keeps deprecated ``RandomState``-style methods.

    Every legacy method first emits a deprecation message via ``deprecation``
    and then forwards to the equivalent ``Generator`` API.
    """

    def rand(self, *size):
        deprecation(
            "Function `rng.rand(*size)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `Generator.random(size)` instead."
        )
        # Legacy ``rand()`` with no arguments returns a scalar float; forwarding
        # the empty tuple would make ``Generator.random`` return a 0-d array.
        return self.random(size or None)

    def random_sample(self, *args, size=None):
        """Deprecated ``RandomState.random_sample`` shim; prefer ``rng.random(size)``.

        Supports the legacy ``random_sample(size)`` / ``random_sample(size=...)``
        forms and also the ``rand``-style ``random_sample(d0, d1, ...)`` form
        that this method accepted when it was an alias of ``rand``.
        """
        deprecation(
            "Function `rng.random_sample(size)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.random(size)` instead."
        )
        if args:
            if size is not None:
                raise TypeError(
                    "random_sample() got both positional dimensions and `size`"
                )
            # A single positional argument is the legacy ``size`` (int or tuple);
            # several are ``rand``-style individual dimensions.
            size = args[0] if len(args) == 1 else args
        return self.random(size)

    def randn(self, *size):
        deprecation(
            "Function `rng.randn(*size)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.standard_normal(size)` instead."
        )
        # Same scalar-vs-0-d-array concern as ``rand`` when called with no args.
        return self.standard_normal(size or None)

    def randint(self, low, high=None, size=None, dtype=int):
        deprecation(
            "Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.integers(low, [high, size, dtype])` instead."
        )
        # ``integers`` defaults to endpoint=False, i.e. the half-open [low, high)
        # interval, matching legacy ``randint``.
        return self.integers(low=low, high=high, size=size, dtype=dtype)

    def random_integers(self, low, high=None, size=None, dtype=int):
        """Deprecated ``RandomState.random_integers`` shim.

        Samples integers from the *closed* interval ``[low, high]``; when
        ``high`` is omitted the interval is ``[1, low]``. Prefer
        ``rng.integers(low, high, size, endpoint=True)``.
        """
        deprecation(
            "Function `rng.random_integers(low, [high, size])` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.integers(low, high, size, endpoint=True)` instead."
        )
        if high is None:
            low, high = 1, low
        # Unlike ``randint``, the upper bound is inclusive.
        return self.integers(
            low=low, high=high, size=size, dtype=dtype, endpoint=True
        )

    def get_state(self):
        deprecation(
            "Function `rng.get_state()` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.bit_generator.state` instead."
        )
        return self.bit_generator.state

    def set_state(self, state):
        deprecation(
            "Function `rng.set_state(state)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.bit_generator.state = state` instead."
        )
        self.bit_generator.state = state

    def seed(self, seed=None):
        deprecation(
            "Function `rng.seed(seed)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng, seed = gym.utils.seeding.np_random(seed)` to create a separate generator instead."
        )
        # Reseed in place by copying the state of a freshly seeded bit
        # generator of the same type.
        self.bit_generator.state = type(self.bit_generator)(seed).state

    # Reuse NumPy's legacy docstrings for the shims that mirror them directly.
    rand.__doc__ = np.random.rand.__doc__
    randn.__doc__ = np.random.randn.__doc__
    randint.__doc__ = np.random.randint.__doc__
    get_state.__doc__ = np.random.get_state.__doc__
    set_state.__doc__ = np.random.set_state.__doc__
    seed.__doc__ = np.random.seed.__doc__
|
|
|
|
|
|
|
|
|
|
|
|
# Short module-level alias for RandomNumberGenerator.
RNG = RandomNumberGenerator
|
|
|
|
|
|
|
|
# Legacy functions
|
|
|
|
|
|
|
|
|
2021-12-22 19:12:57 +01:00
|
|
|
def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
    """Hash a seed through SHA-512 to break up correlations between related seeds.

    Many PRNGs are usually live at once (for example one per environment
    process), and there is evidence that linearly correlated seeds can produce
    correlated output streams:

    http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
    http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
    http://dl.acm.org/citation.cfm?id=1276928

    Seeds are therefore hashed before use. This is not meant to be
    cryptographically strong; it only removes simple correlations.

    Args:
        seed: Seed to hash; ``None`` first draws one from an OS-specific
            randomness source via ``create_seed``.
        max_bytes: Number of leading digest bytes kept in the result.

    Returns:
        An integer built from the first ``max_bytes`` bytes of the SHA-512
        digest of ``str(seed)``.
    """
    deprecation(
        "Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
    )
    source = create_seed(max_bytes=max_bytes) if seed is None else seed
    digest = hashlib.sha512(str(source).encode("utf8")).digest()
    return _bigint_from_bytes(digest[:max_bytes])
|
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2021-12-22 19:12:57 +01:00
|
|
|
def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
    """Produce a robust integer seed.

    Avoids the old Python 2 behaviour of seeding from the system clock, which
    is fragile when many processes start concurrently.

    Args:
        a: ``None`` draws ``max_bytes`` bytes from ``os.urandom``; a ``str`` is
            UTF-8 encoded, extended with its SHA-512 digest and truncated; an
            ``int`` is reduced modulo ``2 ** (8 * max_bytes)``.
        max_bytes: Maximum number of bytes of entropy in the seed.

    Returns:
        A non-negative integer seed.

    Raises:
        error.Error: if ``a`` is not ``None``, ``str`` or ``int``.
    """
    deprecation(
        "Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
    )
    # Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
    if a is None:
        return _bigint_from_bytes(os.urandom(max_bytes))
    if isinstance(a, str):
        encoded = a.encode("utf8")
        # NOTE(review): only the first ``max_bytes`` bytes of (text + digest)
        # survive the slice, so strings sharing a ``max_bytes``-byte prefix map
        # to the same seed. Left as-is so legacy seeds stay reproducible.
        return _bigint_from_bytes((encoded + hashlib.sha512(encoded).digest())[:max_bytes])
    if isinstance(a, int):
        return int(a % 2 ** (8 * max_bytes))
    raise error.Error(f"Invalid type for seed: {type(a)} ({a})")
|
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2016-05-29 09:07:09 -07:00
|
|
|
def _bigint_from_bytes(bt: bytes) -> int:
    """Interpret ``bt`` as an unsigned little-endian integer.

    This used to zero-pad ``bt`` to a multiple of 4 bytes, unpack it as
    native-order C ``unsigned int`` words and sum them with weights
    ``2 ** (32 * i)``. That equals a little-endian read of the bytes, but only
    on little-endian hosts with 4-byte ``unsigned int``. ``int.from_bytes``
    yields the same value on every platform. Trailing zero padding never
    changed the result, and empty input gives 0.
    """
    deprecation(
        "Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. "
    )
    return int.from_bytes(bt, "little")
|
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2021-12-22 19:12:57 +01:00
|
|
|
def _int_list_from_bigint(bigint: int) -> List[int]:
    """Split a non-negative integer into 32-bit words, least significant first.

    Args:
        bigint: The integer to split.

    Returns:
        The base-``2 ** 32`` digits of ``bigint`` in little-endian order;
        ``[0]`` for zero.

    Raises:
        error.Error: if ``bigint`` is negative.
    """
    deprecation(
        "Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
    )
    if bigint < 0:
        raise error.Error(f"Seed must be non-negative, not {bigint}")
    if bigint == 0:
        # Zero still produces a single word instead of an empty list.
        return [0]

    words: List[int] = []
    remaining = bigint
    while remaining > 0:
        remaining, word = divmod(remaining, 2 ** 32)
        words.append(word)
    return words
|