Files
Gymnasium/gym/utils/seeding.py
Andrea PIERRÉ e913bc81b8 Improve pre-commit workflow (#2602)
* feat: add `isort` to `pre-commit`

* ci: skip `__init__.py` file for `isort`

* ci: make `isort` mandatory in lint pipeline

* docs: add a section on Git hooks

* ci: check isort diff

* fix: isort from master branch

* docs: add pre-commit badge

* ci: update black + bandit versions

* feat: add PR template

* refactor: PR template

* ci: remove bandit

* docs: add Black badge

* ci: try to remove all `|| true` statements

* ci: remove lint_python job

- Remove `lint_python` CI job
- Move `pyupgrade` job to `pre-commit` workflow

* fix: avoid messing with typing

* docs: add a note on running `pre-cpmmit` manually

* ci: apply `pre-commit` to the whole codebase
2022-03-31 15:50:38 -04:00

204 lines
7.4 KiB
Python

import hashlib
import os
import struct
from typing import Any, List, Optional, Tuple, Union
import numpy as np
from numpy.random import Generator
from gym import error
from gym.logger import deprecation
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
seed_seq = np.random.SeedSequence(seed)
np_seed = seed_seq.entropy
rng = RandomNumberGenerator(np.random.PCG64(seed_seq))
return rng, np_seed
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
# RandomNumberGenerator = np.random.Generator
class RandomNumberGenerator(np.random.Generator):
def rand(self, *size):
deprecation(
"Function `rng.rand(*size)` is marked as deprecated "
"and will be removed in the future. "
"Please use `Generator.random(size)` instead."
)
return self.random(size)
random_sample = rand
def randn(self, *size):
deprecation(
"Function `rng.randn(*size)` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.standard_normal(size)` instead."
)
return self.standard_normal(size)
def randint(self, low, high=None, size=None, dtype=int):
deprecation(
"Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.integers(low, [high, size, dtype])` instead."
)
return self.integers(low=low, high=high, size=size, dtype=dtype)
random_integers = randint
def get_state(self):
deprecation(
"Function `rng.get_state()` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.bit_generator.state` instead."
)
return self.bit_generator.state
def set_state(self, state):
deprecation(
"Function `rng.set_state(state)` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.bit_generator.state = state` instead."
)
self.bit_generator.state = state
def seed(self, seed=None):
deprecation(
"Function `rng.seed(seed)` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng, seed = gym.utils.seeding.np_random(seed)` to create a separate generator instead."
)
self.bit_generator.state = type(self.bit_generator)(seed).state
rand.__doc__ = np.random.rand.__doc__
randn.__doc__ = np.random.randn.__doc__
randint.__doc__ = np.random.randint.__doc__
get_state.__doc__ = np.random.get_state.__doc__
set_state.__doc__ = np.random.set_state.__doc__
seed.__doc__ = np.random.seed.__doc__
def __reduce__(self):
# np.random.Generator defines __reduce__, but it's hard-coded to
# return a Generator instead of its subclass RandomNumberGenerator.
# We need to override it here, otherwise sampling from a Space will
# be broken after pickling and unpickling, due to using the deprecated
# methods defined above.
# See: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_generator.pyx#L221-L223
# And: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_pickle.py#L17-L37
_, init_args, *args = np.random.Generator.__reduce__(self)
return (RandomNumberGenerator._generator_ctor, init_args, *args)
@staticmethod
def _generator_ctor(bit_generator_name="MT19937"):
# Workaround method for RandomNumberGenerator pickling, see __reduce__ above.
# Ported from numpy.random._pickle.__generator_ctor function.
from numpy.random._pickle import BitGenerators
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(
f"{bit_generator_name} is not a known BitGenerator module."
)
return RandomNumberGenerator(bit_generator())
RNG = RandomNumberGenerator
# Legacy functions
def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
"""Any given evaluation is likely to have many PRNG's active at
once. (Most commonly, because the environment is running in
multiple processes.) There's literature indicating that having
linear correlations between seeds of multiple PRNG's can correlate
the outputs:
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
http://dl.acm.org/citation.cfm?id=1276928
Thus, for sanity we hash the seeds before using them. (This scheme
is likely not crypto-strength, but it should be good enough to get
rid of simple correlations.)
Args:
seed: None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the hashed seed.
"""
deprecation(
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
)
if seed is None:
seed = create_seed(max_bytes=max_bytes)
hash = hashlib.sha512(str(seed).encode("utf8")).digest()
return _bigint_from_bytes(hash[:max_bytes])
def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
"""Create a strong random seed. Otherwise, Python 2 would seed using
the system time, which might be non-robust especially in the
presence of concurrency.
Args:
a: None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the seed.
"""
deprecation(
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
)
# Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
if a is None:
a = _bigint_from_bytes(os.urandom(max_bytes))
elif isinstance(a, str):
bt = a.encode("utf8")
bt += hashlib.sha512(bt).digest()
a = _bigint_from_bytes(bt[:max_bytes])
elif isinstance(a, int):
a = int(a % 2 ** (8 * max_bytes))
else:
raise error.Error(f"Invalid type for seed: {type(a)} ({a})")
return a
# TODO: don't hardcode sizeof_int here
def _bigint_from_bytes(bt: bytes) -> int:
deprecation(
"Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. "
)
sizeof_int = 4
padding = sizeof_int - len(bt) % sizeof_int
bt += b"\0" * padding
int_count = int(len(bt) / sizeof_int)
unpacked = struct.unpack(f"{int_count}I", bt)
accum = 0
for i, val in enumerate(unpacked):
accum += 2 ** (sizeof_int * 8 * i) * val
return accum
def _int_list_from_bigint(bigint: int) -> List[int]:
deprecation(
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
)
# Special case 0
if bigint < 0:
raise error.Error(f"Seed must be non-negative, not {bigint}")
elif bigint == 0:
return [0]
ints: List[int] = []
while bigint > 0:
bigint, mod = divmod(bigint, 2**32)
ints.append(mod)
return ints