mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Type hint fixes and added __all__ dunder (#321)
This commit is contained in:
@@ -8,3 +8,6 @@ These are not intended as API functions, and will not remain stable over time.
|
|||||||
# that verify that our dependencies are actually present.
|
# that verify that our dependencies are actually present.
|
||||||
from gymnasium.utils.colorize import colorize
|
from gymnasium.utils.colorize import colorize
|
||||||
from gymnasium.utils.ezpickle import EzPickle
|
from gymnasium.utils.ezpickle import EzPickle
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["colorize", "EzPickle"]
|
||||||
|
@@ -4,9 +4,10 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, List, Optional, Sequence, Tuple, Union
|
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import logger
|
from gymnasium import logger
|
||||||
@@ -287,7 +288,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def step_wait(
|
def step_wait(
|
||||||
self, timeout: Optional[Union[int, float]] = None
|
self, timeout: Optional[Union[int, float]] = None
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[dict]]:
|
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
|
||||||
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -1,8 +1,9 @@
|
|||||||
"""A synchronous vector environment."""
|
"""A synchronous vector environment."""
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union
|
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from gymnasium import Env
|
from gymnasium import Env
|
||||||
from gymnasium.spaces import Space
|
from gymnasium.spaces import Space
|
||||||
@@ -132,7 +133,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
|
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
|
||||||
self._actions = iterate(self.action_space, actions)
|
self._actions = iterate(self.action_space, actions)
|
||||||
|
|
||||||
def step_wait(self):
|
def step_wait(self) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
|
||||||
"""Steps through each of the environments returning the batched results.
|
"""Steps through each of the environments returning the batched results.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
from typing import Any, List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.vector.utils.spaces import batch_space
|
from gymnasium.vector.utils.spaces import batch_space
|
||||||
@@ -146,7 +147,9 @@ class VectorEnv(gym.Env):
|
|||||||
actions: The actions to take asynchronously
|
actions: The actions to take asynchronously
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def step_wait(self, **kwargs):
|
def step_wait(
|
||||||
|
self, **kwargs
|
||||||
|
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
|
||||||
"""Retrieves the results of a :meth:`step_async` call.
|
"""Retrieves the results of a :meth:`step_async` call.
|
||||||
|
|
||||||
A call to this method must always be preceded by a call to :meth:`step_async`.
|
A call to this method must always be preceded by a call to :meth:`step_async`.
|
||||||
@@ -157,8 +160,11 @@ class VectorEnv(gym.Env):
|
|||||||
Returns:
|
Returns:
|
||||||
The results from the :meth:`step_async` call
|
The results from the :meth:`step_async` call
|
||||||
"""
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def step(self, actions):
|
def step(
|
||||||
|
self, actions
|
||||||
|
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
|
||||||
"""Take an action for each parallel environment.
|
"""Take an action for each parallel environment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -39,6 +39,7 @@ def test_reset_async_vector_env(shared_memory):
|
|||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
assert isinstance(observations, np.ndarray)
|
assert isinstance(observations, np.ndarray)
|
||||||
|
assert isinstance(infos, dict)
|
||||||
assert observations.dtype == env.observation_space.dtype
|
assert observations.dtype == env.observation_space.dtype
|
||||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||||
assert observations.shape == env.observation_space.shape
|
assert observations.shape == env.observation_space.shape
|
||||||
|
@@ -29,6 +29,7 @@ def test_reset_sync_vector_env():
|
|||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
assert isinstance(observations, np.ndarray)
|
assert isinstance(observations, np.ndarray)
|
||||||
|
assert isinstance(infos, dict)
|
||||||
assert observations.dtype == env.observation_space.dtype
|
assert observations.dtype == env.observation_space.dtype
|
||||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||||
assert observations.shape == env.observation_space.shape
|
assert observations.shape == env.observation_space.shape
|
||||||
@@ -130,15 +131,18 @@ def test_custom_space_sync_vector_env():
|
|||||||
|
|
||||||
assert isinstance(env.single_action_space, CustomSpace)
|
assert isinstance(env.single_action_space, CustomSpace)
|
||||||
assert isinstance(env.action_space, Tuple)
|
assert isinstance(env.action_space, Tuple)
|
||||||
|
assert isinstance(infos, dict)
|
||||||
|
|
||||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||||
step_observations, rewards, terminateds, truncateds, _ = env.step(actions)
|
step_observations, rewards, terminateds, truncateds, infos = env.step(actions)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.single_observation_space, CustomSpace)
|
assert isinstance(env.single_observation_space, CustomSpace)
|
||||||
assert isinstance(env.observation_space, Tuple)
|
assert isinstance(env.observation_space, Tuple)
|
||||||
|
|
||||||
|
assert isinstance(infos, dict)
|
||||||
|
|
||||||
assert isinstance(reset_observations, tuple)
|
assert isinstance(reset_observations, tuple)
|
||||||
assert reset_observations == ("reset", "reset", "reset", "reset")
|
assert reset_observations == ("reset", "reset", "reset", "reset")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user