Files
Gymnasium/gym/utils/seeding.py
Ilya Kamen ad79b0ad0f typing in gym.spaces (#2541)
* typing in spaces.Box and spaces.Discrete

* adds typing to dict and tuple spaces

* Typecheck all spaces

* Explicit regex to include all files under space folder

* Style: use native types and __future__ annotations

* Allow only specific strings for Box.is_bounded args

* Add typing to changes from #2517

* Remove Literal as it's not supported by py3.7

* Use more recent version of pyright

* Avoid name clash for type checker

* Revert "Avoid name clash for type checker"

This reverts commit 1aaf3e0e0328171623a17a997b65fe734bc0afb1.

* Ignore the error. It's reported as probable bug at https://github.com/microsoft/pyright/issues/2852

* rebase and add typing for `_short_repr`
2022-01-24 17:22:11 -05:00

179 lines
6.1 KiB
Python

import hashlib
from typing import Optional, List, Tuple, Union, Any
import os
import struct
import numpy as np
from numpy.random import Generator
from gym import error
from gym.logger import deprecation
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
    """Build a random number generator from an optional integer seed.

    Args:
        seed: A non-negative Python ``int`` used to seed the generator, or
            ``None`` to let ``np.random.SeedSequence`` pick its own entropy.

    Returns:
        A ``(rng, np_seed)`` pair: a ``RandomNumberGenerator`` wrapping a
        ``PCG64`` bit generator, and the ``entropy`` attribute of the
        ``SeedSequence`` that the bit generator was built from.

    Raises:
        error.Error: If ``seed`` is given but is not a non-negative ``int``.
    """
    # ``None`` is always accepted; anything else must be an int that is >= 0.
    seed_is_invalid = seed is not None and (not isinstance(seed, int) or seed < 0)
    if seed_is_invalid:
        raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")

    sequence = np.random.SeedSequence(seed)
    generator = RandomNumberGenerator(np.random.PCG64(sequence))
    return generator, sequence.entropy
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
# RandomNumberGenerator = np.random.Generator
class RandomNumberGenerator(np.random.Generator):
    """``np.random.Generator`` with deprecated legacy-style method names.

    Every extra method first reports itself through ``deprecation`` and then
    forwards to the corresponding ``Generator`` method or ``bit_generator``
    attribute. The docstrings of the wrappers are copied from the
    same-named functions in ``np.random``.
    """

    # --- uniform samples -------------------------------------------------
    def rand(self, *size):
        deprecation(
            "Function `rng.rand(*size)` is marked as deprecated and will be removed in the future. "
            "Please use `Generator.random(size)` instead."
        )
        return self.random(size)

    rand.__doc__ = np.random.rand.__doc__
    # Same function object exposed under a second legacy name.
    random_sample = rand

    # --- standard normal samples -----------------------------------------
    def randn(self, *size):
        deprecation(
            "Function `rng.randn(*size)` is marked as deprecated and will be removed in the future. "
            "Please use `rng.standard_normal(size)` instead."
        )
        return self.standard_normal(size)

    randn.__doc__ = np.random.randn.__doc__

    # --- integer samples -------------------------------------------------
    def randint(self, low, high=None, size=None, dtype=int):
        deprecation(
            "Function `rng.randint(low, [high, size, dtype])` is marked as deprecated and will be removed in the future. "
            "Please use `rng.integers(low, [high, size, dtype])` instead."
        )
        return self.integers(low=low, high=high, size=size, dtype=dtype)

    randint.__doc__ = np.random.randint.__doc__
    # Same function object exposed under a second legacy name.
    random_integers = randint

    # --- state access ----------------------------------------------------
    def get_state(self):
        deprecation(
            "Function `rng.get_state()` is marked as deprecated and will be removed in the future. "
            "Please use `rng.bit_generator.state` instead."
        )
        return self.bit_generator.state

    get_state.__doc__ = np.random.get_state.__doc__

    def set_state(self, state):
        deprecation(
            "Function `rng.set_state(state)` is marked as deprecated and will be removed in the future. "
            "Please use `rng.bit_generator.state = state` instead."
        )
        self.bit_generator.state = state

    set_state.__doc__ = np.random.set_state.__doc__

    # --- reseeding -------------------------------------------------------
    def seed(self, seed=None):
        deprecation(
            "Function `rng.seed(seed)` is marked as deprecated and will be removed in the future. "
            "Please use `rng, seed = gym.utils.seeding.np_random(seed)` to create a separate generator instead."
        )
        # Build a fresh bit generator of the same type from ``seed`` and copy
        # its state into the existing one, so the generator object is reused.
        bit_generator_cls = type(self.bit_generator)
        reseeded = bit_generator_cls(seed)
        self.bit_generator.state = reseeded.state

    seed.__doc__ = np.random.seed.__doc__


# Short alias for the class above.
RNG = RandomNumberGenerator
# Legacy functions
def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
    """Hash a seed so that related seeds do not give correlated PRNG streams.

    Many PRNGs are usually active at once during an evaluation (most
    commonly because the environment runs in several processes), and there
    is literature indicating that linearly correlated seeds can correlate
    the outputs of those PRNGs:

    http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
    http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
    http://dl.acm.org/citation.cfm?id=1276928

    The seed is therefore hashed before use. The scheme is likely not
    crypto-strength, but should be enough to remove simple correlations.

    Args:
        seed: Seed to hash. When ``None``, a seed is first obtained from
            ``create_seed(max_bytes=max_bytes)``.
        max_bytes: Number of leading SHA-512 digest bytes kept for the result.

    Returns:
        The integer that ``_bigint_from_bytes`` builds from the first
        ``max_bytes`` bytes of the SHA-512 digest of ``str(seed)``.
    """
    deprecation(
        "Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
    )
    if seed is None:
        seed = create_seed(max_bytes=max_bytes)

    seed_text = str(seed).encode("utf8")
    digest = hashlib.sha512(seed_text).digest()
    truncated = digest[:max_bytes]
    return _bigint_from_bytes(truncated)
def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
    """Create a strong random seed.

    Without this, Python 2 would seed from the system time, which may not
    be robust, especially in the presence of concurrency.

    Args:
        a: ``None`` to draw ``max_bytes`` bytes from ``os.urandom``; a ``str``
            to derive the seed from the string followed by its SHA-512
            digest; or an ``int`` to reduce modulo ``2 ** (8 * max_bytes)``.
        max_bytes: Maximum number of bytes to use in the seed.

    Returns:
        The seed as an integer.

    Raises:
        error.Error: If ``a`` is neither ``None``, a ``str`` nor an ``int``.
    """
    deprecation(
        "Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
    )
    # Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
    if a is None:
        return _bigint_from_bytes(os.urandom(max_bytes))

    if isinstance(a, str):
        encoded = a.encode("utf8")
        # Append the digest so that short strings still yield enough bytes.
        extended = encoded + hashlib.sha512(encoded).digest()
        return _bigint_from_bytes(extended[:max_bytes])

    if isinstance(a, int):
        modulus = 2 ** (8 * max_bytes)
        return int(a % modulus)

    raise error.Error(f"Invalid type for seed: {type(a)} ({a})")
def _bigint_from_bytes(bt: bytes) -> int:
    """Interpret ``bt`` as an unsigned little-endian integer.

    The previous implementation zero-padded the input to a multiple of four
    bytes, unpacked it with ``struct`` as native-order ``unsigned int``
    words and summed the words with the first word least significant. On
    little-endian hosts that is exactly the little-endian value of ``bt``.
    ``int.from_bytes`` gives that same value on every platform, with no
    hard-coded word size, no padding, and without modifying a mutable
    buffer passed in by the caller.

    Args:
        bt: The bytes to convert. An empty value converts to ``0``.

    Returns:
        The non-negative integer encoded by ``bt``.
    """
    deprecation(
        "Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. "
    )
    # Trailing zero padding would only add zero high-order bytes, so it is
    # not needed for the result.
    return int.from_bytes(bt, "little")
def _int_list_from_bigint(bigint: int) -> List[int]:
    """Split a non-negative integer into base ``2 ** 32`` digits.

    Args:
        bigint: The non-negative integer to split.

    Returns:
        The digits with the least significant one first; ``[0]`` when
        ``bigint`` is zero.

    Raises:
        error.Error: If ``bigint`` is negative.
    """
    deprecation(
        "Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
    )
    if bigint < 0:
        raise error.Error(f"Seed must be non-negative, not {bigint}")
    # Special case 0: the loop below would otherwise produce an empty list.
    if bigint == 0:
        return [0]

    base = 2 ** 32
    words: List[int] = []
    remaining = bigint
    while remaining > 0:
        remaining, low_word = divmod(remaining, base)
        words.append(low_word)
    return words