Type hint fixes and added __all__ dunder (#321)

This commit is contained in:
Howard Huang
2023-02-12 12:42:32 -05:00
committed by GitHub
parent 9bc0bf308d
commit 79ae76ed1e
6 changed files with 23 additions and 7 deletions

View File

@@ -8,3 +8,6 @@ These are not intended as API functions, and will not remain stable over time.
# that verify that our dependencies are actually present. # that verify that our dependencies are actually present.
from gymnasium.utils.colorize import colorize from gymnasium.utils.colorize import colorize
from gymnasium.utils.ezpickle import EzPickle from gymnasium.utils.ezpickle import EzPickle
__all__ = ["colorize", "EzPickle"]

View File

@@ -4,9 +4,10 @@ import sys
import time import time
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Callable, List, Optional, Sequence, Tuple, Union from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
from numpy.typing import NDArray
import gymnasium as gym import gymnasium as gym
from gymnasium import logger from gymnasium import logger
@@ -287,7 +288,7 @@ class AsyncVectorEnv(VectorEnv):
def step_wait( def step_wait(
self, timeout: Optional[Union[int, float]] = None self, timeout: Optional[Union[int, float]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[dict]]: ) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish. """Wait for the calls to :obj:`step` in each sub-environment to finish.
Args: Args:

View File

@@ -1,8 +1,9 @@
"""A synchronous vector environment.""" """A synchronous vector environment."""
from copy import deepcopy from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
from numpy.typing import NDArray
from gymnasium import Env from gymnasium import Env
from gymnasium.spaces import Space from gymnasium.spaces import Space
@@ -132,7 +133,7 @@ class SyncVectorEnv(VectorEnv):
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version.""" """Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
self._actions = iterate(self.action_space, actions) self._actions = iterate(self.action_space, actions)
def step_wait(self): def step_wait(self) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Steps through each of the environments returning the batched results. """Steps through each of the environments returning the batched results.
Returns: Returns:

View File

@@ -2,6 +2,7 @@
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import numpy as np import numpy as np
from numpy.typing import NDArray
import gymnasium as gym import gymnasium as gym
from gymnasium.vector.utils.spaces import batch_space from gymnasium.vector.utils.spaces import batch_space
@@ -146,7 +147,9 @@ class VectorEnv(gym.Env):
actions: The actions to take asynchronously actions: The actions to take asynchronously
""" """
def step_wait(self, **kwargs): def step_wait(
self, **kwargs
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Retrieves the results of a :meth:`step_async` call. """Retrieves the results of a :meth:`step_async` call.
A call to this method must always be preceded by a call to :meth:`step_async`. A call to this method must always be preceded by a call to :meth:`step_async`.
@@ -157,8 +160,11 @@ class VectorEnv(gym.Env):
Returns: Returns:
The results from the :meth:`step_async` call The results from the :meth:`step_async` call
""" """
raise NotImplementedError()
def step(self, actions): def step(
self, actions
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Take an action for each parallel environment. """Take an action for each parallel environment.
Args: Args:

View File

@@ -39,6 +39,7 @@ def test_reset_async_vector_env(shared_memory):
assert isinstance(env.observation_space, Box) assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray) assert isinstance(observations, np.ndarray)
assert isinstance(infos, dict)
assert observations.dtype == env.observation_space.dtype assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape assert observations.shape == env.observation_space.shape

View File

@@ -29,6 +29,7 @@ def test_reset_sync_vector_env():
assert isinstance(env.observation_space, Box) assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray) assert isinstance(observations, np.ndarray)
assert isinstance(infos, dict)
assert observations.dtype == env.observation_space.dtype assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape assert observations.shape == env.observation_space.shape
@@ -130,15 +131,18 @@ def test_custom_space_sync_vector_env():
assert isinstance(env.single_action_space, CustomSpace) assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple) assert isinstance(env.action_space, Tuple)
assert isinstance(infos, dict)
actions = ("action-2", "action-3", "action-5", "action-7") actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, terminateds, truncateds, _ = env.step(actions) step_observations, rewards, terminateds, truncateds, infos = env.step(actions)
env.close() env.close()
assert isinstance(env.single_observation_space, CustomSpace) assert isinstance(env.single_observation_space, CustomSpace)
assert isinstance(env.observation_space, Tuple) assert isinstance(env.observation_space, Tuple)
assert isinstance(infos, dict)
assert isinstance(reset_observations, tuple) assert isinstance(reset_observations, tuple)
assert reset_observations == ("reset", "reset", "reset", "reset") assert reset_observations == ("reset", "reset", "reset", "reset")