mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
Made readout of seed possible in env (#889)
This commit is contained in:
committed by
GitHub
parent
9e83d5442c
commit
b3f0361f91
@@ -54,6 +54,7 @@ title: Env
|
||||
|
||||
.. autoproperty:: gymnasium.Env.unwrapped
|
||||
.. autoproperty:: gymnasium.Env.np_random
|
||||
.. autoproperty:: gymnasium.Env.np_random_seed
|
||||
```
|
||||
|
||||
## Implementing environments
|
||||
|
@@ -67,6 +67,7 @@ vector/utils
|
||||
```{eval-rst}
|
||||
.. autoproperty:: gymnasium.vector.VectorEnv.unwrapped
|
||||
.. autoproperty:: gymnasium.vector.VectorEnv.np_random
|
||||
.. autoproperty:: gymnasium.vector.VectorEnv.np_random_seed
|
||||
```
|
||||
|
||||
## Making Vector Environments
|
||||
|
@@ -11,3 +11,10 @@
|
||||
.. automethod:: gymnasium.vector.AsyncVectorEnv.get_attr
|
||||
.. automethod:: gymnasium.vector.AsyncVectorEnv.set_attr
|
||||
```
|
||||
|
||||
### Additional Methods
|
||||
|
||||
```{eval-rst}
|
||||
.. autoproperty:: gymnasium.vector.VectorEnv.np_random
|
||||
.. autoproperty:: gymnasium.vector.VectorEnv.np_random_seed
|
||||
```
|
||||
|
@@ -11,3 +11,10 @@
|
||||
.. automethod:: gymnasium.vector.SyncVectorEnv.get_attr
|
||||
.. automethod:: gymnasium.vector.SyncVectorEnv.set_attr
|
||||
```
|
||||
|
||||
### Additional Methods
|
||||
|
||||
```{eval-rst}
|
||||
.. autoproperty:: gymnasium.vector.VectorEnv.np_random
|
||||
.. autoproperty:: gymnasium.vector.VectorEnv.np_random_seed
|
||||
```
|
||||
|
@@ -47,5 +47,6 @@ wrappers/reward_wrappers
|
||||
.. autoproperty:: gymnasium.Wrapper.spec
|
||||
.. autoproperty:: gymnasium.Wrapper.metadata
|
||||
.. autoproperty:: gymnasium.Wrapper.np_random
|
||||
.. autoproperty:: gymnasium.Wrapper.np_random_seed
|
||||
.. autoproperty:: gymnasium.Wrapper.unwrapped
|
||||
```
|
||||
|
@@ -66,6 +66,8 @@ class Env(Generic[ObsType, ActType]):
|
||||
|
||||
# Created
|
||||
_np_random: np.random.Generator | None = None
|
||||
# will be set to the "invalid" value -1 if the seed of the currently set rng is unknown
|
||||
_np_random_seed: int | None = None
|
||||
|
||||
def step(
|
||||
self, action: ActType
|
||||
@@ -130,10 +132,12 @@ class Env(Generic[ObsType, ActType]):
|
||||
The ``return_info`` parameter was removed and now info is expected to be returned.
|
||||
|
||||
Args:
|
||||
seed (optional int): The seed that is used to initialize the environment's PRNG (`np_random`).
|
||||
seed (optional int): The seed that is used to initialize the environment's PRNG (`np_random`) and
|
||||
the read-only attribute `np_random_seed`.
|
||||
If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed,
|
||||
a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
|
||||
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset.
|
||||
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset
|
||||
and the env's :attr:`np_random_seed` will *not* be altered.
|
||||
If you pass an integer, the PRNG will be reset even if it already exists.
|
||||
Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
|
||||
Please refer to the minimal example above to see this paradigm in action.
|
||||
@@ -148,7 +152,7 @@ class Env(Generic[ObsType, ActType]):
|
||||
"""
|
||||
# Initialize the RNG if the seed is manually passed
|
||||
if seed is not None:
|
||||
self._np_random, seed = seeding.np_random(seed)
|
||||
self._np_random, self._np_random_seed = seeding.np_random(seed)
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
"""Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment.
|
||||
@@ -201,6 +205,20 @@ class Env(Generic[ObsType, ActType]):
|
||||
"""
|
||||
return self
|
||||
|
||||
@property
|
||||
def np_random_seed(self) -> int:
|
||||
"""Returns the environment's internal :attr:`_np_random_seed` that if not set will first initialise with a random int as seed.
|
||||
|
||||
If :attr:`np_random_seed` was set directly instead of through :meth:`reset` or :meth:`set_np_random_through_seed`,
|
||||
the seed will take the value -1.
|
||||
|
||||
Returns:
|
||||
int: the seed of the current `np_random` or -1, if the seed of the rng is unknown
|
||||
"""
|
||||
if self._np_random_seed is None:
|
||||
self._np_random, self._np_random_seed = seeding.np_random()
|
||||
return self._np_random_seed
|
||||
|
||||
@property
|
||||
def np_random(self) -> np.random.Generator:
|
||||
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
|
||||
@@ -209,12 +227,20 @@ class Env(Generic[ObsType, ActType]):
|
||||
Instances of `np.random.Generator`
|
||||
"""
|
||||
if self._np_random is None:
|
||||
self._np_random, _ = seeding.np_random()
|
||||
self._np_random, self._np_random_seed = seeding.np_random()
|
||||
return self._np_random
|
||||
|
||||
@np_random.setter
|
||||
def np_random(self, value: np.random.Generator):
|
||||
"""Sets the environment's internal :attr:`_np_random` with the user-provided Generator.
|
||||
|
||||
Since it is generally not possible to extract a seed from an instance of a random number generator,
|
||||
this will also set the :attr:`_np_random_seed` to `-1`, which is not valid as input for the creation
|
||||
of a numpy rng.
|
||||
"""
|
||||
self._np_random = value
|
||||
# Setting a numpy rng with -1 will cause a ValueError
|
||||
self._np_random_seed = -1
|
||||
|
||||
def __str__(self):
|
||||
"""Returns a string of the environment with :attr:`spec` id's if :attr:`spec.
|
||||
@@ -303,6 +329,11 @@ class Wrapper(
|
||||
"""Closes the wrapper and :attr:`env`."""
|
||||
return self.env.close()
|
||||
|
||||
@property
|
||||
def np_random_seed(self) -> int | None:
|
||||
"""Returns the base enviroment's :attr:`np_random_seed`."""
|
||||
return self.env.np_random_seed
|
||||
|
||||
@property
|
||||
def unwrapped(self) -> Env[ObsType, ActType]:
|
||||
"""Returns the base environment of the wrapper.
|
||||
|
@@ -197,6 +197,16 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._state = AsyncState.DEFAULT
|
||||
self._check_spaces()
|
||||
|
||||
@property
|
||||
def np_random_seed(self) -> tuple[int, ...]:
|
||||
"""Returns the seeds of the wrapped envs."""
|
||||
return self.get_attr("np_random_seed")
|
||||
|
||||
@property
|
||||
def np_random(self) -> tuple[np.random.Generator, ...]:
|
||||
"""Returns the numpy random number generators of the wrapped envs."""
|
||||
return self.get_attr("np_random")
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
@@ -240,7 +250,9 @@ class AsyncVectorEnv(VectorEnv):
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
elif isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
assert (
|
||||
len(seed) == self.num_envs
|
||||
), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}."
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
@@ -472,7 +484,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
return results
|
||||
|
||||
def get_attr(self, name: str):
|
||||
def get_attr(self, name: str) -> tuple[Any, ...]:
|
||||
"""Get a property from each parallel environment.
|
||||
|
||||
Args:
|
||||
|
@@ -100,6 +100,16 @@ class SyncVectorEnv(VectorEnv):
|
||||
|
||||
self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
|
||||
@property
|
||||
def np_random_seed(self) -> tuple[int, ...]:
|
||||
"""Returns the seeds of the wrapped envs."""
|
||||
return self.get_attr("np_random_seed")
|
||||
|
||||
@property
|
||||
def np_random(self) -> tuple[np.random.Generator, ...]:
|
||||
"""Returns the numpy random number generators of the wrapped envs."""
|
||||
return self.get_attr("np_random")
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
@@ -122,7 +132,9 @@ class SyncVectorEnv(VectorEnv):
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
elif isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
assert (
|
||||
len(seed) == self.num_envs
|
||||
), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}."
|
||||
|
||||
self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
@@ -211,7 +223,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
|
||||
return tuple(results)
|
||||
|
||||
def get_attr(self, name: str) -> Any:
|
||||
def get_attr(self, name: str) -> tuple[Any, ...]:
|
||||
"""Get a property from each parallel environment.
|
||||
|
||||
Args:
|
||||
|
@@ -104,17 +104,18 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
||||
num_envs: int
|
||||
|
||||
_np_random: np.random.Generator | None = None
|
||||
_np_random_seed: int | None = None
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
seed: int | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]: # type: ignore
|
||||
"""Reset all parallel environments and return a batch of initial observations and info.
|
||||
|
||||
Args:
|
||||
seed: The environment reset seeds
|
||||
seed: The environment reset seed
|
||||
options: If to return the options
|
||||
|
||||
Returns:
|
||||
@@ -133,7 +134,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
||||
{}
|
||||
"""
|
||||
if seed is not None:
|
||||
self._np_random, seed = seeding.np_random(seed)
|
||||
self._np_random, self._np_random_seed = seeding.np_random(seed)
|
||||
|
||||
def step(
|
||||
self, actions: ActType
|
||||
@@ -210,6 +211,20 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
||||
"""Clean up the extra resources e.g. beyond what's in this base class."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def np_random_seed(self) -> int | None:
|
||||
"""Returns the environment's internal :attr:`_np_random_seed` that if not set will first initialise with a random int as seed.
|
||||
|
||||
If :attr:`np_random_seed` was set directly instead of through :meth:`reset` or :meth:`set_np_random_through_seed`,
|
||||
the seed will take the value -1.
|
||||
|
||||
Returns:
|
||||
int: the seed of the current `np_random` or -1, if the seed of the rng is unknown
|
||||
"""
|
||||
if self._np_random_seed is None:
|
||||
self._np_random, self._np_random_seed = seeding.np_random()
|
||||
return self._np_random_seed
|
||||
|
||||
@property
|
||||
def np_random(self) -> np.random.Generator:
|
||||
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
|
||||
@@ -218,12 +233,13 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
||||
Instances of `np.random.Generator`
|
||||
"""
|
||||
if self._np_random is None:
|
||||
self._np_random, seed = seeding.np_random()
|
||||
self._np_random, self._np_random_seed = seeding.np_random()
|
||||
return self._np_random
|
||||
|
||||
@np_random.setter
|
||||
def np_random(self, value: np.random.Generator):
|
||||
self._np_random = value
|
||||
self._np_random_seed = -1
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
@@ -430,6 +446,19 @@ class VectorWrapper(VectorEnv):
|
||||
"""Returns the `render_mode` from the base environment."""
|
||||
return self.env.render_mode
|
||||
|
||||
@property
|
||||
def np_random(self) -> np.random.Generator:
|
||||
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
|
||||
|
||||
Returns:
|
||||
Instances of `np.random.Generator`
|
||||
"""
|
||||
return self.env.np_random
|
||||
|
||||
@np_random.setter
|
||||
def np_random(self, value: np.random.Generator):
|
||||
self.env.np_random = value
|
||||
|
||||
|
||||
class VectorObservationWrapper(VectorWrapper):
|
||||
"""Wraps the vectorized environment to allow a modular transformation of the observation.
|
||||
|
@@ -12,6 +12,7 @@ from gymnasium import ActionWrapper, Env, ObservationWrapper, RewardWrapper, Wra
|
||||
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
|
||||
from gymnasium.spaces import Box
|
||||
from gymnasium.utils import seeding
|
||||
from gymnasium.utils.seeding import np_random
|
||||
from gymnasium.wrappers import OrderEnforcing
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
@@ -37,17 +38,22 @@ class ExampleEnv(Env):
|
||||
options: dict | None = None,
|
||||
) -> tuple[ObsType, dict]:
|
||||
"""Resets the environment."""
|
||||
super().reset(seed=seed, options=options)
|
||||
return 0, {}
|
||||
|
||||
|
||||
def test_example_env():
|
||||
"""Tests a gymnasium environment."""
|
||||
env = ExampleEnv()
|
||||
@pytest.fixture
|
||||
def example_env():
|
||||
return ExampleEnv()
|
||||
|
||||
assert env.metadata == {"render_modes": []}
|
||||
assert env.render_mode is None
|
||||
assert env.spec is None
|
||||
assert env._np_random is None # pyright: ignore [reportPrivateUsage]
|
||||
|
||||
def test_example_env(example_env):
|
||||
"""Tests a gymnasium environment."""
|
||||
|
||||
assert example_env.metadata == {"render_modes": []}
|
||||
assert example_env.render_mode is None
|
||||
assert example_env.spec is None
|
||||
assert example_env._np_random is None # pyright: ignore [reportPrivateUsage]
|
||||
|
||||
|
||||
class ExampleWrapper(Wrapper):
|
||||
@@ -77,9 +83,9 @@ class ExampleWrapper(Wrapper):
|
||||
return self._np_random
|
||||
|
||||
|
||||
def test_example_wrapper():
|
||||
def test_example_wrapper(example_env):
|
||||
"""Tests the gymnasium wrapper works as expected."""
|
||||
env = ExampleEnv()
|
||||
env = example_env
|
||||
wrapper_env = ExampleWrapper(env)
|
||||
|
||||
assert env.metadata == wrapper_env.metadata
|
||||
@@ -202,3 +208,45 @@ def test_get_set_wrapper_attr():
|
||||
with pytest.raises(AttributeError):
|
||||
env.unwrapped._disable_render_order_enforcing
|
||||
assert env.get_wrapper_attr("_disable_render_order_enforcing") is True
|
||||
|
||||
|
||||
class TestRandomSeeding:
|
||||
@staticmethod
|
||||
def test_nonempty_seed_retrieved_when_not_set(example_env):
|
||||
assert example_env.np_random_seed is not None
|
||||
assert isinstance(example_env.np_random_seed, int)
|
||||
|
||||
@staticmethod
|
||||
def test_seed_set_at_reset_and_retrieved(example_env):
|
||||
seed = 42
|
||||
example_env.reset(seed=seed)
|
||||
assert example_env.np_random_seed == seed
|
||||
# resetting with seed=None means seed remains the same
|
||||
example_env.reset(seed=None)
|
||||
assert example_env.np_random_seed == seed
|
||||
|
||||
@staticmethod
|
||||
def test_seed_cannot_be_set_directly(example_env):
|
||||
with pytest.raises(AttributeError):
|
||||
example_env.np_random_seed = 42
|
||||
|
||||
@staticmethod
|
||||
def test_negative_seed_retrieved_when_seed_unknown(example_env):
|
||||
rng, _ = np_random()
|
||||
example_env.np_random = rng
|
||||
# seed is unknown
|
||||
assert example_env.np_random_seed == -1
|
||||
|
||||
@staticmethod
|
||||
def test_seeding_works_in_wrapped_envs(example_env):
|
||||
seed = 42
|
||||
wrapper_env = ExampleWrapper(example_env)
|
||||
wrapper_env.reset(seed=seed)
|
||||
assert wrapper_env.np_random_seed == seed
|
||||
# resetting with seed=None means seed remains the same
|
||||
wrapper_env.reset(seed=None)
|
||||
assert wrapper_env.np_random_seed == seed
|
||||
# setting np_random directly makes seed unknown
|
||||
rng, _ = np_random()
|
||||
wrapper_env.np_random = rng
|
||||
assert wrapper_env.np_random_seed == -1
|
||||
|
@@ -122,3 +122,46 @@ def test_final_obs_info(vectoriser):
|
||||
)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_env_list():
|
||||
"""Example vector environment."""
|
||||
return [make_env("CartPole-v1", i) for i in range(4)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"venv_constructor",
|
||||
[
|
||||
SyncVectorEnv,
|
||||
partial(AsyncVectorEnv, shared_memory=True),
|
||||
partial(AsyncVectorEnv, shared_memory=False),
|
||||
],
|
||||
)
|
||||
def test_random_seeding_basics(venv_constructor, example_env_list):
|
||||
seed = 42
|
||||
vector_env = venv_constructor(example_env_list)
|
||||
vector_env.reset(seed=seed)
|
||||
assert vector_env.np_random_seed == tuple(
|
||||
seed + i for i in range(vector_env.num_envs)
|
||||
)
|
||||
# resetting with seed=None means seed remains the same
|
||||
vector_env.reset(seed=None)
|
||||
assert vector_env.np_random_seed == tuple(
|
||||
seed + i for i in range(vector_env.num_envs)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"venv_constructor",
|
||||
[
|
||||
SyncVectorEnv,
|
||||
partial(AsyncVectorEnv, shared_memory=True),
|
||||
partial(AsyncVectorEnv, shared_memory=False),
|
||||
],
|
||||
)
|
||||
def test_random_seeds_set_at_retrieval(venv_constructor, example_env_list):
|
||||
vector_env = venv_constructor(example_env_list)
|
||||
assert len(set(vector_env.np_random_seed)) == vector_env.num_envs
|
||||
# default seed starts at zero. Adjust or remove this test if the default seed changes
|
||||
assert vector_env.np_random_seed == tuple(range(vector_env.num_envs))
|
||||
|
Reference in New Issue
Block a user