Type hint fixes and added __all__ dunder (#321)

This commit is contained in:
Howard Huang
2023-02-12 12:42:32 -05:00
committed by GitHub
parent 9bc0bf308d
commit 79ae76ed1e
6 changed files with 23 additions and 7 deletions

View File

@@ -8,3 +8,6 @@ These are not intended as API functions, and will not remain stable over time.
# that verify that our dependencies are actually present.
from gymnasium.utils.colorize import colorize
from gymnasium.utils.ezpickle import EzPickle
__all__ = ["colorize", "EzPickle"]

View File

@@ -4,9 +4,10 @@ import sys
import time
from copy import deepcopy
from enum import Enum
from typing import Callable, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
from numpy.typing import NDArray
import gymnasium as gym
from gymnasium import logger
@@ -287,7 +288,7 @@ class AsyncVectorEnv(VectorEnv):
def step_wait(
self, timeout: Optional[Union[int, float]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[dict]]:
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
Args:

View File

@@ -1,8 +1,9 @@
"""A synchronous vector environment."""
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
from numpy.typing import NDArray
from gymnasium import Env
from gymnasium.spaces import Space
@@ -132,7 +133,7 @@ class SyncVectorEnv(VectorEnv):
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
self._actions = iterate(self.action_space, actions)
def step_wait(self):
def step_wait(self) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Steps through each of the environments returning the batched results.
Returns:

View File

@@ -2,6 +2,7 @@
from typing import Any, List, Optional, Tuple, Union
import numpy as np
from numpy.typing import NDArray
import gymnasium as gym
from gymnasium.vector.utils.spaces import batch_space
@@ -146,7 +147,9 @@ class VectorEnv(gym.Env):
actions: The actions to take asynchronously
"""
def step_wait(self, **kwargs):
def step_wait(
self, **kwargs
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Retrieves the results of a :meth:`step_async` call.
A call to this method must always be preceded by a call to :meth:`step_async`.
@@ -157,8 +160,11 @@ class VectorEnv(gym.Env):
Returns:
The results from the :meth:`step_async` call
"""
raise NotImplementedError()
def step(self, actions):
def step(
self, actions
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Take an action for each parallel environment.
Args:

View File

@@ -39,6 +39,7 @@ def test_reset_async_vector_env(shared_memory):
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert isinstance(infos, dict)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape

View File

@@ -29,6 +29,7 @@ def test_reset_sync_vector_env():
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert isinstance(infos, dict)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
@@ -130,15 +131,18 @@ def test_custom_space_sync_vector_env():
assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple)
assert isinstance(infos, dict)
actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, terminateds, truncateds, _ = env.step(actions)
step_observations, rewards, terminateds, truncateds, infos = env.step(actions)
env.close()
assert isinstance(env.single_observation_space, CustomSpace)
assert isinstance(env.observation_space, Tuple)
assert isinstance(infos, dict)
assert isinstance(reset_observations, tuple)
assert reset_observations == ("reset", "reset", "reset", "reset")