Made readout of seed possible in env (#889)

This commit is contained in:
Michael Panchenko
2024-02-26 13:00:18 +01:00
committed by GitHub
parent 9e83d5442c
commit b3f0361f91
11 changed files with 213 additions and 21 deletions

View File

@@ -54,6 +54,7 @@ title: Env
.. autoproperty:: gymnasium.Env.unwrapped
.. autoproperty:: gymnasium.Env.np_random
.. autoproperty:: gymnasium.Env.np_random_seed
```
## Implementing environments

View File

@@ -67,6 +67,7 @@ vector/utils
```{eval-rst}
.. autoproperty:: gymnasium.vector.VectorEnv.unwrapped
.. autoproperty:: gymnasium.vector.VectorEnv.np_random
.. autoproperty:: gymnasium.vector.VectorEnv.np_random_seed
```
## Making Vector Environments

View File

@@ -11,3 +11,10 @@
.. automethod:: gymnasium.vector.AsyncVectorEnv.get_attr
.. automethod:: gymnasium.vector.AsyncVectorEnv.set_attr
```
### Additional Methods
```{eval-rst}
.. autoproperty:: gymnasium.vector.VectorEnv.np_random
.. autoproperty:: gymnasium.vector.VectorEnv.np_random_seed
```

View File

@@ -11,3 +11,10 @@
.. automethod:: gymnasium.vector.SyncVectorEnv.get_attr
.. automethod:: gymnasium.vector.SyncVectorEnv.set_attr
```
### Additional Methods
```{eval-rst}
.. autoproperty:: gymnasium.vector.VectorEnv.np_random
.. autoproperty:: gymnasium.vector.VectorEnv.np_random_seed
```

View File

@@ -47,5 +47,6 @@ wrappers/reward_wrappers
.. autoproperty:: gymnasium.Wrapper.spec
.. autoproperty:: gymnasium.Wrapper.metadata
.. autoproperty:: gymnasium.Wrapper.np_random
.. autoproperty:: gymnasium.Wrapper.np_random_seed
.. autoproperty:: gymnasium.Wrapper.unwrapped
```

View File

@@ -66,6 +66,8 @@ class Env(Generic[ObsType, ActType]):
# Created
_np_random: np.random.Generator | None = None
# will be set to the "invalid" value -1 if the seed of the currently set rng is unknown
_np_random_seed: int | None = None
def step(
self, action: ActType
@@ -130,10 +132,12 @@ class Env(Generic[ObsType, ActType]):
The ``return_info`` parameter was removed and now info is expected to be returned.
Args:
seed (optional int): The seed that is used to initialize the environment's PRNG (`np_random`).
seed (optional int): The seed that is used to initialize the environment's PRNG (`np_random`) and
the read-only attribute `np_random_seed`.
If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed,
a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset.
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset
and the env's :attr:`np_random_seed` will *not* be altered.
If you pass an integer, the PRNG will be reset even if it already exists.
Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
Please refer to the minimal example above to see this paradigm in action.
@@ -148,7 +152,7 @@ class Env(Generic[ObsType, ActType]):
"""
# Initialize the RNG if the seed is manually passed
if seed is not None:
self._np_random, seed = seeding.np_random(seed)
self._np_random, self._np_random_seed = seeding.np_random(seed)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment.
@@ -201,6 +205,20 @@ class Env(Generic[ObsType, ActType]):
"""
return self
@property
def np_random_seed(self) -> int:
"""Returns the environment's internal :attr:`_np_random_seed` that if not set will first initialise with a random int as seed.
If :attr:`np_random_seed` was set directly instead of through :meth:`reset` or :meth:`set_np_random_through_seed`,
the seed will take the value -1.
Returns:
int: the seed of the current `np_random` or -1, if the seed of the rng is unknown
"""
if self._np_random_seed is None:
self._np_random, self._np_random_seed = seeding.np_random()
return self._np_random_seed
@property
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
@@ -209,12 +227,20 @@ class Env(Generic[ObsType, ActType]):
Instances of `np.random.Generator`
"""
if self._np_random is None:
self._np_random, _ = seeding.np_random()
self._np_random, self._np_random_seed = seeding.np_random()
return self._np_random
@np_random.setter
def np_random(self, value: np.random.Generator):
"""Sets the environment's internal :attr:`_np_random` with the user-provided Generator.
Since it is generally not possible to extract a seed from an instance of a random number generator,
this will also set the :attr:`_np_random_seed` to `-1`, which is not valid as input for the creation
of a numpy rng.
"""
self._np_random = value
# Setting a numpy rng with -1 will cause a ValueError
self._np_random_seed = -1
def __str__(self):
"""Returns a string of the environment with :attr:`spec` id's if :attr:`spec.
@@ -303,6 +329,11 @@ class Wrapper(
"""Closes the wrapper and :attr:`env`."""
return self.env.close()
@property
def np_random_seed(self) -> int | None:
"""Returns the base enviroment's :attr:`np_random_seed`."""
return self.env.np_random_seed
@property
def unwrapped(self) -> Env[ObsType, ActType]:
"""Returns the base environment of the wrapper.

View File

@@ -197,6 +197,16 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT
self._check_spaces()
@property
def np_random_seed(self) -> tuple[int, ...]:
"""Returns the seeds of the wrapped envs."""
return self.get_attr("np_random_seed")
@property
def np_random(self) -> tuple[np.random.Generator, ...]:
"""Returns the numpy random number generators of the wrapped envs."""
return self.get_attr("np_random")
def reset(
self,
*,
@@ -240,7 +250,9 @@ class AsyncVectorEnv(VectorEnv):
seed = [None for _ in range(self.num_envs)]
elif isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
assert (
len(seed) == self.num_envs
), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}."
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
@@ -472,7 +484,7 @@ class AsyncVectorEnv(VectorEnv):
return results
def get_attr(self, name: str):
def get_attr(self, name: str) -> tuple[Any, ...]:
"""Get a property from each parallel environment.
Args:

View File

@@ -100,6 +100,16 @@ class SyncVectorEnv(VectorEnv):
self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)
@property
def np_random_seed(self) -> tuple[int, ...]:
"""Returns the seeds of the wrapped envs."""
return self.get_attr("np_random_seed")
@property
def np_random(self) -> tuple[np.random.Generator, ...]:
"""Returns the numpy random number generators of the wrapped envs."""
return self.get_attr("np_random")
def reset(
self,
*,
@@ -122,7 +132,9 @@ class SyncVectorEnv(VectorEnv):
seed = [None for _ in range(self.num_envs)]
elif isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
assert (
len(seed) == self.num_envs
), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}."
self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
@@ -211,7 +223,7 @@ class SyncVectorEnv(VectorEnv):
return tuple(results)
def get_attr(self, name: str) -> Any:
def get_attr(self, name: str) -> tuple[Any, ...]:
"""Get a property from each parallel environment.
Args:

View File

@@ -104,17 +104,18 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
num_envs: int
_np_random: np.random.Generator | None = None
_np_random_seed: int | None = None
def reset(
self,
*,
seed: int | list[int] | None = None,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]: # type: ignore
"""Reset all parallel environments and return a batch of initial observations and info.
Args:
seed: The environment reset seeds
seed: The environment reset seed
options: If to return the options
Returns:
@@ -133,7 +134,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
{}
"""
if seed is not None:
self._np_random, seed = seeding.np_random(seed)
self._np_random, self._np_random_seed = seeding.np_random(seed)
def step(
self, actions: ActType
@@ -210,6 +211,20 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
"""Clean up the extra resources e.g. beyond what's in this base class."""
pass
@property
def np_random_seed(self) -> int | None:
"""Returns the environment's internal :attr:`_np_random_seed` that if not set will first initialise with a random int as seed.
If :attr:`np_random_seed` was set directly instead of through :meth:`reset` or :meth:`set_np_random_through_seed`,
the seed will take the value -1.
Returns:
int: the seed of the current `np_random` or -1, if the seed of the rng is unknown
"""
if self._np_random_seed is None:
self._np_random, self._np_random_seed = seeding.np_random()
return self._np_random_seed
@property
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
@@ -218,12 +233,13 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
Instances of `np.random.Generator`
"""
if self._np_random is None:
self._np_random, seed = seeding.np_random()
self._np_random, self._np_random_seed = seeding.np_random()
return self._np_random
@np_random.setter
def np_random(self, value: np.random.Generator):
self._np_random = value
self._np_random_seed = -1
@property
def unwrapped(self):
@@ -430,6 +446,19 @@ class VectorWrapper(VectorEnv):
"""Returns the `render_mode` from the base environment."""
return self.env.render_mode
@property
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
Returns:
Instances of `np.random.Generator`
"""
return self.env.np_random
@np_random.setter
def np_random(self, value: np.random.Generator):
self.env.np_random = value
class VectorObservationWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the observation.

View File

@@ -12,6 +12,7 @@ from gymnasium import ActionWrapper, Env, ObservationWrapper, RewardWrapper, Wra
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
from gymnasium.spaces import Box
from gymnasium.utils import seeding
from gymnasium.utils.seeding import np_random
from gymnasium.wrappers import OrderEnforcing
from tests.testing_env import GenericTestEnv
@@ -37,17 +38,22 @@ class ExampleEnv(Env):
options: dict | None = None,
) -> tuple[ObsType, dict]:
"""Resets the environment."""
super().reset(seed=seed, options=options)
return 0, {}
def test_example_env():
"""Tests a gymnasium environment."""
env = ExampleEnv()
@pytest.fixture
def example_env():
return ExampleEnv()
assert env.metadata == {"render_modes": []}
assert env.render_mode is None
assert env.spec is None
assert env._np_random is None # pyright: ignore [reportPrivateUsage]
def test_example_env(example_env):
"""Tests a gymnasium environment."""
assert example_env.metadata == {"render_modes": []}
assert example_env.render_mode is None
assert example_env.spec is None
assert example_env._np_random is None # pyright: ignore [reportPrivateUsage]
class ExampleWrapper(Wrapper):
@@ -77,9 +83,9 @@ class ExampleWrapper(Wrapper):
return self._np_random
def test_example_wrapper():
def test_example_wrapper(example_env):
"""Tests the gymnasium wrapper works as expected."""
env = ExampleEnv()
env = example_env
wrapper_env = ExampleWrapper(env)
assert env.metadata == wrapper_env.metadata
@@ -202,3 +208,45 @@ def test_get_set_wrapper_attr():
with pytest.raises(AttributeError):
env.unwrapped._disable_render_order_enforcing
assert env.get_wrapper_attr("_disable_render_order_enforcing") is True
class TestRandomSeeding:
@staticmethod
def test_nonempty_seed_retrieved_when_not_set(example_env):
assert example_env.np_random_seed is not None
assert isinstance(example_env.np_random_seed, int)
@staticmethod
def test_seed_set_at_reset_and_retrieved(example_env):
seed = 42
example_env.reset(seed=seed)
assert example_env.np_random_seed == seed
# resetting with seed=None means seed remains the same
example_env.reset(seed=None)
assert example_env.np_random_seed == seed
@staticmethod
def test_seed_cannot_be_set_directly(example_env):
with pytest.raises(AttributeError):
example_env.np_random_seed = 42
@staticmethod
def test_negative_seed_retrieved_when_seed_unknown(example_env):
rng, _ = np_random()
example_env.np_random = rng
# seed is unknown
assert example_env.np_random_seed == -1
@staticmethod
def test_seeding_works_in_wrapped_envs(example_env):
seed = 42
wrapper_env = ExampleWrapper(example_env)
wrapper_env.reset(seed=seed)
assert wrapper_env.np_random_seed == seed
# resetting with seed=None means seed remains the same
wrapper_env.reset(seed=None)
assert wrapper_env.np_random_seed == seed
# setting np_random directly makes seed unknown
rng, _ = np_random()
wrapper_env.np_random = rng
assert wrapper_env.np_random_seed == -1

View File

@@ -122,3 +122,46 @@ def test_final_obs_info(vectoriser):
)
env.close()
@pytest.fixture
def example_env_list():
"""Example vector environment."""
return [make_env("CartPole-v1", i) for i in range(4)]
@pytest.mark.parametrize(
"venv_constructor",
[
SyncVectorEnv,
partial(AsyncVectorEnv, shared_memory=True),
partial(AsyncVectorEnv, shared_memory=False),
],
)
def test_random_seeding_basics(venv_constructor, example_env_list):
seed = 42
vector_env = venv_constructor(example_env_list)
vector_env.reset(seed=seed)
assert vector_env.np_random_seed == tuple(
seed + i for i in range(vector_env.num_envs)
)
# resetting with seed=None means seed remains the same
vector_env.reset(seed=None)
assert vector_env.np_random_seed == tuple(
seed + i for i in range(vector_env.num_envs)
)
@pytest.mark.parametrize(
"venv_constructor",
[
SyncVectorEnv,
partial(AsyncVectorEnv, shared_memory=True),
partial(AsyncVectorEnv, shared_memory=False),
],
)
def test_random_seeds_set_at_retrieval(venv_constructor, example_env_list):
vector_env = venv_constructor(example_env_list)
assert len(set(vector_env.np_random_seed)) == vector_env.num_envs
# default seed starts at zero. Adjust or remove this test if the default seed changes
assert vector_env.np_random_seed == tuple(range(vector_env.num_envs))