removed inline RandomNumberGenerator implementation (#3022)

This commit is contained in:
John Balis
2022-08-22 09:20:28 -04:00
committed by GitHub
parent 78c3faac04
commit 7b0b85cbca
12 changed files with 22 additions and 232 deletions

View File

@@ -13,10 +13,11 @@ from typing import (
Union,
)
import numpy as np
from gym import spaces
from gym.logger import deprecation, warn
from gym.utils import seeding
from gym.utils.seeding import RandomNumberGenerator
if TYPE_CHECKING:
from gym.envs.registration import EnvSpec
@@ -105,17 +106,17 @@ class Env(Generic[ObsType, ActType]):
observation_space: spaces.Space[ObsType]
# Created
_np_random: Optional[RandomNumberGenerator] = None
_np_random: Optional[np.random.Generator] = None
@property
def np_random(self) -> RandomNumberGenerator:
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed."""
if self._np_random is None:
self._np_random, seed = seeding.np_random()
return self._np_random
@np_random.setter
def np_random(self, value: RandomNumberGenerator):
def np_random(self, value: np.random.Generator):
self._np_random = value
def step(
@@ -384,7 +385,7 @@ class Wrapper(Env[ObsType, ActType]):
return self.env.render_mode
@property
def np_random(self) -> RandomNumberGenerator:
def np_random(self) -> np.random.Generator:
"""Returns the environment np_random."""
return self.env.np_random

View File

@@ -1,9 +1,7 @@
import numpy as np
from gym.utils import seeding
def categorical_sample(prob_n, np_random: seeding.RandomNumberGenerator):
def categorical_sample(prob_n, np_random: np.random.Generator):
"""Sample from categorical distribution where each row specifies class probabilities."""
prob_n = np.asarray(prob_n)
csprob_n = np.cumsum(prob_n)

View File

@@ -6,7 +6,6 @@ import numpy as np
import gym.error
from gym import logger
from gym.spaces.space import Space
from gym.utils import seeding
def _short_repr(arr: np.ndarray) -> str:
@@ -57,7 +56,7 @@ class Box(Space[np.ndarray]):
high: Union[SupportsFloat, np.ndarray],
shape: Optional[Sequence[int]] = None,
dtype: Type = np.float32,
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[int, np.random.Generator]] = None,
):
r"""Constructor of :class:`Box`.

View File

@@ -8,7 +8,6 @@ from typing import Optional, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class Dict(Space[TypingDict[str, Space]], Mapping):
@@ -53,7 +52,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
def __init__(
self,
spaces: Optional[TypingDict[str, Space]] = None,
seed: Optional[Union[dict, int, seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[dict, int, np.random.Generator]] = None,
**spaces_kwargs: Space,
):
"""Constructor of :class:`Dict` space.

View File

@@ -4,7 +4,6 @@ from typing import Optional, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class Discrete(Space[int]):
@@ -21,7 +20,7 @@ class Discrete(Space[int]):
def __init__(
self,
n: int,
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[int, np.random.Generator]] = None,
start: int = 0,
):
r"""Constructor of :class:`Discrete` space.

View File

@@ -8,7 +8,6 @@ from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import SAMPLE_MASK_TYPE, MultiDiscrete
from gym.spaces.space import Space
from gym.utils import seeding
class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"])):
@@ -40,7 +39,7 @@ class Graph(Space):
self,
node_space: Union[Box, Discrete],
edge_space: Union[None, Box, Discrete],
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[int, np.random.Generator]] = None,
):
r"""Constructor of :class:`Graph`.

View File

@@ -4,7 +4,6 @@ from typing import Optional, Sequence, Tuple, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class MultiBinary(Space[np.ndarray]):
@@ -27,7 +26,7 @@ class MultiBinary(Space[np.ndarray]):
def __init__(
self,
n: Union[np.ndarray, Sequence[int], int],
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[int, np.random.Generator]] = None,
):
"""Constructor of :class:`MultiBinary` space.

View File

@@ -6,7 +6,6 @@ import numpy as np
from gym import logger
from gym.spaces.discrete import Discrete
from gym.spaces.space import Space
from gym.utils import seeding
SAMPLE_MASK_TYPE = Tuple[Union["SAMPLE_MASK_TYPE", np.ndarray], ...]
@@ -42,7 +41,7 @@ class MultiDiscrete(Space[np.ndarray]):
self,
nvec: Union[np.ndarray, List[int]],
dtype=np.int64,
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[int, np.random.Generator]] = None,
):
"""Constructor of :class:`MultiDiscrete` space.

View File

@@ -51,7 +51,7 @@ class Space(Generic[T_cov]):
self,
shape: Optional[Sequence[int]] = None,
dtype: Optional[Union[Type, str, np.dtype]] = None,
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[int, np.random.Generator]] = None,
):
"""Constructor of :class:`Space`.
@@ -64,13 +64,13 @@ class Space(Generic[T_cov]):
self.dtype = None if dtype is None else np.dtype(dtype)
self._np_random = None
if seed is not None:
if isinstance(seed, seeding.RandomNumberGenerator):
if isinstance(seed, np.random.Generator):
self._np_random = seed
else:
self.seed(seed)
@property
def np_random(self) -> seeding.RandomNumberGenerator:
def np_random(self) -> np.random.Generator:
"""Lazily seed the PRNG since this is expensive and only needed if sampling from this space."""
if self._np_random is None:
self.seed()

View File

@@ -4,7 +4,6 @@ from typing import Any, FrozenSet, List, Optional, Set, Tuple, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
alphanumeric: FrozenSet[str] = frozenset(
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
@@ -30,7 +29,7 @@ class Text(Space[str]):
*,
min_length: int = 0,
charset: Union[Set[str], str] = alphanumeric,
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[int, np.random.Generator]] = None,
):
r"""Constructor of :class:`Text` space.
@@ -90,7 +89,7 @@ class Text(Space[str]):
length, charlist_mask = None, None
if length is None:
length = self.np_random.randint(self.min_length, self.max_length + 1)
length = self.np_random.integers(self.min_length, self.max_length + 1)
if charlist_mask is None:
string = self.np_random.choice(self._charlist, size=length)

View File

@@ -4,7 +4,6 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class Tuple(Space[tuple], Sequence):
@@ -23,7 +22,7 @@ class Tuple(Space[tuple], Sequence):
def __init__(
self,
spaces: Iterable[Space],
seed: Optional[Union[int, List[int], seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[int, List[int], np.random.Generator]] = None,
):
r"""Constructor of :class:`Tuple` space.

View File

@@ -1,16 +1,12 @@
"""Set of random number generator functions: seeding, generator, hashing seeds."""
import hashlib
import os
import struct
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Tuple
import numpy as np
from gym import error
from gym.logger import deprecation
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
def np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, Any]:
"""Generates a random number generator from the seed and returns the Generator and seed.
Args:
@@ -31,201 +27,4 @@ def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]
return rng, np_seed
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
# RandomNumberGenerator = np.random.Generator
class RandomNumberGenerator(np.random.Generator):
"""Random number generator class that inherits from numpy's random Generator class."""
def rand(self, *size):
"""Deprecated rand function using random."""
deprecation(
"Function `rng.rand(*size)` is marked as deprecated "
"and will be removed in the future. "
"Please use `Generator.random(size)` instead."
)
return self.random(size)
random_sample = rand
def randn(self, *size):
"""Deprecated random standard normal function use standard_normal."""
deprecation(
"Function `rng.randn(*size)` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.standard_normal(size)` instead."
)
return self.standard_normal(size)
def randint(self, low, high=None, size=None, dtype=int):
"""Deprecated random integer function use integers."""
deprecation(
"Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.integers(low, [high, size, dtype])` instead."
)
return self.integers(low=low, high=high, size=size, dtype=dtype)
random_integers = randint
def get_state(self):
"""Deprecated get rng state use bit_generator.state."""
deprecation(
"Function `rng.get_state()` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.bit_generator.state` instead."
)
return self.bit_generator.state
def set_state(self, state):
"""Deprecated set rng state function use bit_generator.state = state."""
deprecation(
"Function `rng.set_state(state)` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.bit_generator.state = state` instead."
)
self.bit_generator.state = state
def seed(self, seed=None):
"""Deprecated seed function use gym.utils.seeding.np_random(seed)."""
deprecation(
"Function `rng.seed(seed)` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng, seed = gym.utils.seeding.np_random(seed)` to create a separate generator instead."
)
self.bit_generator.state = type(self.bit_generator)(seed).state
rand.__doc__ = np.random.rand.__doc__
randn.__doc__ = np.random.randn.__doc__
randint.__doc__ = np.random.randint.__doc__
get_state.__doc__ = np.random.get_state.__doc__
set_state.__doc__ = np.random.set_state.__doc__
seed.__doc__ = np.random.seed.__doc__
def __reduce__(self):
"""Reduces the Random Number Generator to a RandomNumberGenerator, init_args and additional args."""
# np.random.Generator defines __reduce__, but it's hard-coded to
# return a Generator instead of its subclass RandomNumberGenerator.
# We need to override it here, otherwise sampling from a Space will
# be broken after pickling and unpickling, due to using the deprecated
# methods defined above.
# See: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_generator.pyx#L221-L223
# And: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_pickle.py#L17-L37
_, init_args, *args = np.random.Generator.__reduce__(self)
return (RandomNumberGenerator._generator_ctor, init_args, *args)
@staticmethod
def _generator_ctor(bit_generator_name="MT19937"):
# Workaround method for RandomNumberGenerator pickling, see __reduce__ above.
# Ported from numpy.random._pickle.__generator_ctor function.
from numpy.random._pickle import BitGenerators
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(
f"{bit_generator_name} is not a known BitGenerator module."
)
return RandomNumberGenerator(bit_generator())
RNG = RandomNumberGenerator
# Legacy functions
def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
"""Any given evaluation is likely to have many PRNG's active at once.
(Most commonly, because the environment is running in multiple processes.)
There's literature indicating that having linear correlations between seeds of multiple PRNG's can correlate the outputs:
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
http://dl.acm.org/citation.cfm?id=1276928
Thus, for sanity we hash the seeds before using them. (This scheme is likely not crypto-strength, but it should be good enough to get rid of simple correlations.)
Args:
seed: None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the hashed seed.
Returns:
The hashed seed
"""
deprecation(
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
)
if seed is None:
seed = create_seed(max_bytes=max_bytes)
hash = hashlib.sha512(str(seed).encode("utf8")).digest()
return _bigint_from_bytes(hash[:max_bytes])
def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
"""Create a strong random seed.
Otherwise, Python 2 would seed using the system time, which might be non-robust especially in the presence of concurrency.
Args:
a: None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the seed.
Returns:
A seed
Raises:
Error: Invalid type for seed, expects None or str or int
"""
deprecation(
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
)
# Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
if a is None:
a = _bigint_from_bytes(os.urandom(max_bytes))
elif isinstance(a, str):
bt = a.encode("utf8")
bt += hashlib.sha512(bt).digest()
a = _bigint_from_bytes(bt[:max_bytes])
elif isinstance(a, int):
a = int(a % 2 ** (8 * max_bytes))
else:
raise error.Error(f"Invalid type for seed: {type(a)} ({a})")
return a
# TODO: don't hardcode sizeof_int here
def _bigint_from_bytes(bt: bytes) -> int:
deprecation(
"Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. "
)
sizeof_int = 4
padding = sizeof_int - len(bt) % sizeof_int
bt += b"\0" * padding
int_count = int(len(bt) / sizeof_int)
unpacked = struct.unpack(f"{int_count}I", bt)
accum = 0
for i, val in enumerate(unpacked):
accum += 2 ** (sizeof_int * 8 * i) * val
return accum
def _int_list_from_bigint(bigint: int) -> List[int]:
deprecation(
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
)
# Special case 0
if bigint < 0:
raise error.Error(f"Seed must be non-negative, not {bigint}")
elif bigint == 0:
return [0]
ints: List[int] = []
while bigint > 0:
bigint, mod = divmod(bigint, 2**32)
ints.append(mod)
return ints
RNG = RandomNumberGenerator = np.random.Generator