mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
Add support for python 3.6 (#2836)
* Add support for python 3.6 * Add support for python 3.6 * Added check for python 3.6 to not install mujoco as no version exists * Fixed the install groups for python 3.6 * Re-added python 3.6 support for gym * black * Added support for dataclasses through dataclasses module in setup that backports the module * Fixed install requirements * Re-added dummy env spec with dataclasses * Changed type for compatability for python 3.6 * Added a python 3.6 warning * Fixed python 3.6 typing issue * Removed __future__ import annotation for python 3.6 support * Fixed python 3.6 typing
This commit is contained in:
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -6,7 +6,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ['3.7', '3.8', '3.9', '3.10']
|
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- run: |
|
- run: |
|
||||||
|
@@ -41,7 +41,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
# TODO: remove `--keep-runtime-typing` option
|
# TODO: remove `--keep-runtime-typing` option
|
||||||
args: ["--py37-plus", "--keep-runtime-typing"]
|
args: ["--py36-plus", "--keep-runtime-typing"]
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyright
|
- id: pyright
|
||||||
|
28
gym/core.py
28
gym/core.py
@@ -1,13 +1,17 @@
|
|||||||
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
||||||
from __future__ import annotations
|
import sys
|
||||||
|
from typing import Generic, Optional, SupportsFloat, Tuple, TypeVar, Union
|
||||||
from typing import Generic, Optional, SupportsFloat, TypeVar, Union
|
|
||||||
|
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from gym.logger import deprecation
|
from gym.logger import deprecation, warn
|
||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
from gym.utils.seeding import RandomNumberGenerator
|
from gym.utils.seeding import RandomNumberGenerator
|
||||||
|
|
||||||
|
if sys.version_info == (3, 6):
|
||||||
|
warn(
|
||||||
|
"Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+"
|
||||||
|
)
|
||||||
|
|
||||||
ObsType = TypeVar("ObsType")
|
ObsType = TypeVar("ObsType")
|
||||||
ActType = TypeVar("ActType")
|
ActType = TypeVar("ActType")
|
||||||
|
|
||||||
@@ -62,7 +66,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
def np_random(self, value: RandomNumberGenerator):
|
def np_random(self, value: RandomNumberGenerator):
|
||||||
self._np_random = value
|
self._np_random = value
|
||||||
|
|
||||||
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
|
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||||
"""Run one timestep of the environment's dynamics.
|
"""Run one timestep of the environment's dynamics.
|
||||||
|
|
||||||
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
|
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
|
||||||
@@ -92,7 +96,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
) -> Union[ObsType, tuple[ObsType, dict]]:
|
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||||
"""Resets the environment to an initial state and returns the initial observation.
|
"""Resets the environment to an initial state and returns the initial observation.
|
||||||
|
|
||||||
This method can reset the environment's random number generator(s) if ``seed`` is an integer or
|
This method can reset the environment's random number generator(s) if ``seed`` is an integer or
|
||||||
@@ -201,7 +205,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
return [seed]
|
return [seed]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self) -> Env:
|
def unwrapped(self) -> "Env":
|
||||||
"""Returns the base non-wrapped environment.
|
"""Returns the base non-wrapped environment.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -248,7 +252,7 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
|
|
||||||
self._action_space: Optional[spaces.Space] = None
|
self._action_space: Optional[spaces.Space] = None
|
||||||
self._observation_space: Optional[spaces.Space] = None
|
self._observation_space: Optional[spaces.Space] = None
|
||||||
self._reward_range: Optional[tuple[SupportsFloat, SupportsFloat]] = None
|
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
|
||||||
self._metadata: Optional[dict] = None
|
self._metadata: Optional[dict] = None
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
@@ -290,14 +294,14 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
self._observation_space = space
|
self._observation_space = space
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]:
|
def reward_range(self) -> Tuple[SupportsFloat, SupportsFloat]:
|
||||||
"""Return the reward range of the environment."""
|
"""Return the reward range of the environment."""
|
||||||
if self._reward_range is None:
|
if self._reward_range is None:
|
||||||
return self.env.reward_range
|
return self.env.reward_range
|
||||||
return self._reward_range
|
return self._reward_range
|
||||||
|
|
||||||
@reward_range.setter
|
@reward_range.setter
|
||||||
def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]):
|
def reward_range(self, value: Tuple[SupportsFloat, SupportsFloat]):
|
||||||
self._reward_range = value
|
self._reward_range = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -311,11 +315,11 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
def metadata(self, value):
|
def metadata(self, value):
|
||||||
self._metadata = value
|
self._metadata = value
|
||||||
|
|
||||||
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
|
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||||
"""Steps through the environment with action."""
|
"""Steps through the environment with action."""
|
||||||
return self.env.step(action)
|
return self.env.step(action)
|
||||||
|
|
||||||
def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]:
|
def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||||
"""Resets the environment with kwargs."""
|
"""Resets the environment with kwargs."""
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
|
@@ -1,5 +1,3 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import difflib
|
import difflib
|
||||||
@@ -10,12 +8,14 @@ import sys
|
|||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
|
||||||
Callable,
|
Callable,
|
||||||
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
SupportsFloat,
|
SupportsFloat,
|
||||||
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
@@ -34,15 +34,11 @@ else:
|
|||||||
if sys.version_info >= (3, 8):
|
if sys.version_info >= (3, 8):
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
else:
|
else:
|
||||||
|
from typing_extensions import Literal
|
||||||
class Literal(str):
|
|
||||||
def __class_getitem__(cls, item):
|
|
||||||
return Any
|
|
||||||
|
|
||||||
|
|
||||||
from gym import Env, error, logger
|
from gym import Env, error, logger
|
||||||
|
|
||||||
ENV_ID_RE: re.Pattern = re.compile(
|
ENV_ID_RE = re.compile(
|
||||||
r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
|
r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,7 +50,7 @@ def load(name: str) -> type:
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def parse_env_id(id: str) -> tuple[Optional[str], str, Optional[int]]:
|
def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
|
||||||
"""Parse environment ID string format.
|
"""Parse environment ID string format.
|
||||||
|
|
||||||
This format is true today, but it's *not* an official spec.
|
This format is true today, but it's *not* an official spec.
|
||||||
@@ -241,7 +237,7 @@ def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
|
|||||||
|
|
||||||
|
|
||||||
def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
|
def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
|
||||||
version: list[int] = [
|
version: List[int] = [
|
||||||
spec_.version
|
spec_.version
|
||||||
for spec_ in registry.values()
|
for spec_ in registry.values()
|
||||||
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
|
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
|
||||||
@@ -302,39 +298,39 @@ def load_env_plugins(entry_point: str = "gym.envs") -> None:
|
|||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
||||||
|
|
||||||
# Box2d
|
# Box2d
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["CarRacing-v1", "CarRacingDomainRandomize-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
def make(id: Literal["CarRacing-v1", "CarRacingDomainRandomize-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
|
||||||
|
|
||||||
# Toy Text
|
# Toy Text
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
||||||
|
|
||||||
# Mujoco
|
# Mujoco
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
@@ -415,7 +411,7 @@ class EnvRegistry(dict):
|
|||||||
|
|
||||||
|
|
||||||
# Global registry of environments. Meant to be accessed through `register` and `make`
|
# Global registry of environments. Meant to be accessed through `register` and `make`
|
||||||
registry: dict[str, EnvSpec] = EnvRegistry()
|
registry: Dict[str, EnvSpec] = EnvRegistry()
|
||||||
current_namespace: Optional[str] = None
|
current_namespace: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -522,7 +518,7 @@ def register(id: str, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def make(
|
def make(
|
||||||
id: str | EnvSpec,
|
id: Union[str, EnvSpec],
|
||||||
max_episode_steps: Optional[int] = None,
|
max_episode_steps: Optional[int] = None,
|
||||||
autoreset: bool = False,
|
autoreset: bool = False,
|
||||||
disable_env_checker: bool = False,
|
disable_env_checker: bool = False,
|
||||||
|
@@ -1,9 +1,7 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -31,7 +29,7 @@ MAPS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def generate_random_map(size: int = 8, p: float = 0.8) -> list[str]:
|
def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]:
|
||||||
"""Generates a random valid map (one that has a path from start to goal)
|
"""Generates a random valid map (one that has a path from start to goal)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -1,7 +1,5 @@
|
|||||||
"""Implementation of a space that represents closed boxes in euclidean space."""
|
"""Implementation of a space that represents closed boxes in euclidean space."""
|
||||||
from __future__ import annotations
|
from typing import List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
||||||
|
|
||||||
from typing import Optional, Sequence, SupportsFloat, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -52,8 +50,8 @@ class Box(Space[np.ndarray]):
|
|||||||
low: Union[SupportsFloat, np.ndarray],
|
low: Union[SupportsFloat, np.ndarray],
|
||||||
high: Union[SupportsFloat, np.ndarray],
|
high: Union[SupportsFloat, np.ndarray],
|
||||||
shape: Optional[Sequence[int]] = None,
|
shape: Optional[Sequence[int]] = None,
|
||||||
dtype: type = np.float32,
|
dtype: Type = np.float32,
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Box`.
|
r"""Constructor of :class:`Box`.
|
||||||
|
|
||||||
@@ -105,7 +103,7 @@ class Box(Space[np.ndarray]):
|
|||||||
assert isinstance(high, np.ndarray)
|
assert isinstance(high, np.ndarray)
|
||||||
assert high.shape == shape, "high.shape doesn't match provided shape"
|
assert high.shape == shape, "high.shape doesn't match provided shape"
|
||||||
|
|
||||||
self._shape: tuple[int, ...] = shape
|
self._shape: Tuple[int, ...] = shape
|
||||||
|
|
||||||
low_precision = get_precision(low.dtype)
|
low_precision = get_precision(low.dtype)
|
||||||
high_precision = get_precision(high.dtype)
|
high_precision = get_precision(high.dtype)
|
||||||
@@ -121,7 +119,7 @@ class Box(Space[np.ndarray]):
|
|||||||
super().__init__(self.shape, self.dtype, seed)
|
super().__init__(self.shape, self.dtype, seed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> tuple[int, ...]:
|
def shape(self) -> Tuple[int, ...]:
|
||||||
"""Has stricter type than gym.Space - never None."""
|
"""Has stricter type than gym.Space - never None."""
|
||||||
return self._shape
|
return self._shape
|
||||||
|
|
||||||
@@ -210,7 +208,7 @@ class Box(Space[np.ndarray]):
|
|||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
return np.array(sample_n).tolist()
|
return np.array(sample_n).tolist()
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> list[np.ndarray]:
|
def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> List[np.ndarray]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [np.asarray(sample) for sample in sample_n]
|
return [np.asarray(sample) for sample in sample_n]
|
||||||
|
|
||||||
@@ -278,7 +276,7 @@ def get_precision(dtype) -> SupportsFloat:
|
|||||||
def _broadcast(
|
def _broadcast(
|
||||||
value: Union[SupportsFloat, np.ndarray],
|
value: Union[SupportsFloat, np.ndarray],
|
||||||
dtype,
|
dtype,
|
||||||
shape: tuple[int, ...],
|
shape: Tuple[int, ...],
|
||||||
inf_sign: str,
|
inf_sign: str,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Handle infinite bounds and broadcast at the same time if needed."""
|
"""Handle infinite bounds and broadcast at the same time if needed."""
|
||||||
|
@@ -1,10 +1,8 @@
|
|||||||
"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
|
"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Dict as TypingDict
|
from typing import Dict as TypingDict
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -53,8 +51,8 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
spaces: Optional[dict[str, Space]] = None,
|
spaces: Optional[TypingDict[str, Space]] = None,
|
||||||
seed: Optional[dict | int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[Union[dict, int, seeding.RandomNumberGenerator]] = None,
|
||||||
**spaces_kwargs: Space,
|
**spaces_kwargs: Space,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`Dict` space.
|
"""Constructor of :class:`Dict` space.
|
||||||
@@ -101,7 +99,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
None, None, seed # type: ignore
|
None, None, seed # type: ignore
|
||||||
) # None for shape and dtype, since it'll require special handling
|
) # None for shape and dtype, since it'll require special handling
|
||||||
|
|
||||||
def seed(self, seed: Optional[dict | int] = None) -> list:
|
def seed(self, seed: Optional[Union[dict, int]] = None) -> list:
|
||||||
"""Seed the PRNG of this space and all subspaces."""
|
"""Seed the PRNG of this space and all subspaces."""
|
||||||
seeds = []
|
seeds = []
|
||||||
if isinstance(seed, dict):
|
if isinstance(seed, dict):
|
||||||
@@ -188,9 +186,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
for key, space in self.spaces.items()
|
for key, space in self.spaces.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: dict[str, list]) -> list:
|
def from_jsonable(self, sample_n: TypingDict[str, list]) -> list:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
dict_of_list: dict[str, list] = {}
|
dict_of_list: TypingDict[str, list] = {}
|
||||||
for key, space in self.spaces.items():
|
for key, space in self.spaces.items():
|
||||||
dict_of_list[key] = space.from_jsonable(sample_n[key])
|
dict_of_list[key] = space.from_jsonable(sample_n[key])
|
||||||
ret = []
|
ret = []
|
||||||
|
@@ -1,7 +1,5 @@
|
|||||||
"""Implementation of a space consisting of finitely many elements."""
|
"""Implementation of a space consisting of finitely many elements."""
|
||||||
from __future__ import annotations
|
from typing import Optional, Union
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -23,7 +21,7 @@ class Discrete(Space[int]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n: int,
|
n: int,
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
|
||||||
start: int = 0,
|
start: int = 0,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Discrete` space.
|
r"""Constructor of :class:`Discrete` space.
|
||||||
|
@@ -1,7 +1,5 @@
|
|||||||
"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
|
"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
|
||||||
from __future__ import annotations
|
from typing import Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from typing import Optional, Sequence, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -29,7 +27,7 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n: Union[np.ndarray, Sequence[int], int],
|
n: Union[np.ndarray, Sequence[int], int],
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`MultiBinary` space.
|
"""Constructor of :class:`MultiBinary` space.
|
||||||
|
|
||||||
@@ -49,7 +47,7 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
super().__init__(input_n, np.int8, seed)
|
super().__init__(input_n, np.int8, seed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> tuple[int, ...]:
|
def shape(self) -> Tuple[int, ...]:
|
||||||
"""Has stricter type than gym.Space - never None."""
|
"""Has stricter type than gym.Space - never None."""
|
||||||
return self._shape # type: ignore
|
return self._shape # type: ignore
|
||||||
|
|
||||||
|
@@ -1,7 +1,5 @@
|
|||||||
"""Implementation of a space that represents the cartesian product of `Discrete` spaces."""
|
"""Implementation of a space that represents the cartesian product of `Discrete` spaces."""
|
||||||
from __future__ import annotations
|
from typing import Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from typing import Iterable, Optional, Sequence, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -31,9 +29,9 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
nvec: Union[np.ndarray, list[int]],
|
nvec: Union[np.ndarray, List[int]],
|
||||||
dtype=np.int64,
|
dtype=np.int64,
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`MultiDiscrete` space.
|
"""Constructor of :class:`MultiDiscrete` space.
|
||||||
|
|
||||||
@@ -61,7 +59,7 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
super().__init__(self.nvec.shape, dtype, seed)
|
super().__init__(self.nvec.shape, dtype, seed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> tuple[int, ...]:
|
def shape(self) -> Tuple[int, ...]:
|
||||||
"""Has stricter type than :class:`gym.Space` - never None."""
|
"""Has stricter type than :class:`gym.Space` - never None."""
|
||||||
return self._shape # type: ignore
|
return self._shape # type: ignore
|
||||||
|
|
||||||
|
@@ -1,7 +1,17 @@
|
|||||||
"""Implementation of the `Space` metaclass."""
|
"""Implementation of the `Space` metaclass."""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Generic, Iterable, Mapping, Optional, Sequence, TypeVar
|
from typing import (
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -39,8 +49,8 @@ class Space(Generic[T_cov]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
shape: Optional[Sequence[int]] = None,
|
shape: Optional[Sequence[int]] = None,
|
||||||
dtype: Optional[type | str | np.dtype] = None,
|
dtype: Optional[Union[Type, str, np.dtype]] = None,
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`Space`.
|
"""Constructor of :class:`Space`.
|
||||||
|
|
||||||
@@ -67,7 +77,7 @@ class Space(Generic[T_cov]):
|
|||||||
return self._np_random # type: ignore ## self.seed() call guarantees right type.
|
return self._np_random # type: ignore ## self.seed() call guarantees right type.
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> Optional[tuple[int, ...]]:
|
def shape(self) -> Optional[Tuple[int, ...]]:
|
||||||
"""Return the shape of the space as an immutable property."""
|
"""Return the shape of the space as an immutable property."""
|
||||||
return self._shape
|
return self._shape
|
||||||
|
|
||||||
@@ -88,7 +98,7 @@ class Space(Generic[T_cov]):
|
|||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
return self.contains(x)
|
return self.contains(x)
|
||||||
|
|
||||||
def __setstate__(self, state: Iterable | Mapping):
|
def __setstate__(self, state: Union[Iterable, Mapping]):
|
||||||
"""Used when loading a pickled space.
|
"""Used when loading a pickled space.
|
||||||
|
|
||||||
This method was implemented explicitly to allow for loading of legacy states.
|
This method was implemented explicitly to allow for loading of legacy states.
|
||||||
@@ -119,7 +129,7 @@ class Space(Generic[T_cov]):
|
|||||||
# By default, assume identity is JSONable
|
# By default, assume identity is JSONable
|
||||||
return list(sample_n)
|
return list(sample_n)
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: list) -> list[T_cov]:
|
def from_jsonable(self, sample_n: list) -> List[T_cov]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
# By default, assume identity is JSONable
|
# By default, assume identity is JSONable
|
||||||
return sample_n
|
return sample_n
|
||||||
|
@@ -1,7 +1,5 @@
|
|||||||
"""Implementation of a space that represents the cartesian product of other spaces."""
|
"""Implementation of a space that represents the cartesian product of other spaces."""
|
||||||
from __future__ import annotations
|
from typing import Iterable, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from typing import Iterable, Optional, Sequence
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -25,7 +23,7 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
spaces: Iterable[Space],
|
spaces: Iterable[Space],
|
||||||
seed: Optional[int | list[int] | seeding.RandomNumberGenerator] = None,
|
seed: Optional[Union[int, List[int], seeding.RandomNumberGenerator]] = None,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Tuple` space.
|
r"""Constructor of :class:`Tuple` space.
|
||||||
|
|
||||||
@@ -43,7 +41,7 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
), "Elements of the tuple must be instances of gym.Space"
|
), "Elements of the tuple must be instances of gym.Space"
|
||||||
super().__init__(None, None, seed) # type: ignore
|
super().__init__(None, None, seed) # type: ignore
|
||||||
|
|
||||||
def seed(self, seed: Optional[int | list[int]] = None) -> list:
|
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
|
||||||
"""Seed the PRNG of this space and all subspaces."""
|
"""Seed the PRNG of this space and all subspaces."""
|
||||||
seeds = []
|
seeds = []
|
||||||
|
|
||||||
|
@@ -3,8 +3,6 @@
|
|||||||
These functions mostly take care of flattening and unflattening elements of spaces
|
These functions mostly take care of flattening and unflattening elements of spaces
|
||||||
to facilitate their usage in learning code.
|
to facilitate their usage in learning code.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import operator as op
|
import operator as op
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import reduce, singledispatch
|
from functools import reduce, singledispatch
|
||||||
@@ -142,7 +140,9 @@ def unflatten(space: Space[T], x: np.ndarray) -> T:
|
|||||||
|
|
||||||
@unflatten.register(Box)
|
@unflatten.register(Box)
|
||||||
@unflatten.register(MultiBinary)
|
@unflatten.register(MultiBinary)
|
||||||
def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.ndarray:
|
def _unflatten_box_multibinary(
|
||||||
|
space: Union[Box, MultiBinary], x: np.ndarray
|
||||||
|
) -> np.ndarray:
|
||||||
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,8 +1,6 @@
|
|||||||
"""Utilities of visualising an environment."""
|
"""Utilities of visualising an environment."""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Callable, Dict, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pygame
|
import pygame
|
||||||
@@ -35,7 +33,7 @@ class PlayableGame:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: Env,
|
env: Env,
|
||||||
keys_to_action: Optional[dict[tuple[int], int]] = None,
|
keys_to_action: Optional[Dict[Tuple[int], int]] = None,
|
||||||
zoom: Optional[float] = None,
|
zoom: Optional[float] = None,
|
||||||
):
|
):
|
||||||
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
|
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
|
||||||
@@ -53,7 +51,7 @@ class PlayableGame:
|
|||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
def _get_relevant_keys(
|
def _get_relevant_keys(
|
||||||
self, keys_to_action: Optional[dict[tuple[int], int]] = None
|
self, keys_to_action: Optional[Dict[Tuple[int], int]] = None
|
||||||
) -> set:
|
) -> set:
|
||||||
if keys_to_action is None:
|
if keys_to_action is None:
|
||||||
if hasattr(self.env, "get_keys_to_action"):
|
if hasattr(self.env, "get_keys_to_action"):
|
||||||
@@ -68,7 +66,7 @@ class PlayableGame:
|
|||||||
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
||||||
return relevant_keys
|
return relevant_keys
|
||||||
|
|
||||||
def _get_video_size(self, zoom: Optional[float] = None) -> tuple[int, int]:
|
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
|
||||||
# TODO: this needs to be updated when the render API change goes through
|
# TODO: this needs to be updated when the render API change goes through
|
||||||
rendered = self.env.render(mode="rgb_array")
|
rendered = self.env.render(mode="rgb_array")
|
||||||
video_size = [rendered.shape[1], rendered.shape[0]]
|
video_size = [rendered.shape[1], rendered.shape[0]]
|
||||||
@@ -103,7 +101,7 @@ class PlayableGame:
|
|||||||
|
|
||||||
|
|
||||||
def display_arr(
|
def display_arr(
|
||||||
screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
|
screen: Surface, arr: np.ndarray, video_size: Tuple[int, int], transpose: bool
|
||||||
):
|
):
|
||||||
"""Displays a numpy array on screen.
|
"""Displays a numpy array on screen.
|
||||||
|
|
||||||
@@ -273,7 +271,7 @@ class PlayPlot:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, callback: callable, horizon_timesteps: int, plot_names: list[str]
|
self, callback: callable, horizon_timesteps: int, plot_names: List[str]
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`PlayPlot`.
|
"""Constructor of :class:`PlayPlot`.
|
||||||
|
|
||||||
|
@@ -1,10 +1,8 @@
|
|||||||
"""Set of random number generator functions: seeding, generator, hashing seeds."""
|
"""Set of random number generator functions: seeding, generator, hashing seeds."""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -12,7 +10,7 @@ from gym import error
|
|||||||
from gym.logger import deprecation
|
from gym.logger import deprecation
|
||||||
|
|
||||||
|
|
||||||
def np_random(seed: Optional[int] = None) -> tuple[RandomNumberGenerator, Any]:
|
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
|
||||||
"""Generates a random number generator from the seed and returns the Generator and seed.
|
"""Generates a random number generator from the seed and returns the Generator and seed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -216,7 +214,7 @@ def _bigint_from_bytes(bt: bytes) -> int:
|
|||||||
return accum
|
return accum
|
||||||
|
|
||||||
|
|
||||||
def _int_list_from_bigint(bigint: int) -> list[int]:
|
def _int_list_from_bigint(bigint: int) -> List[int]:
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
|
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
|
||||||
)
|
)
|
||||||
@@ -226,7 +224,7 @@ def _int_list_from_bigint(bigint: int) -> list[int]:
|
|||||||
elif bigint == 0:
|
elif bigint == 0:
|
||||||
return [0]
|
return [0]
|
||||||
|
|
||||||
ints: list[int] = []
|
ints: List[int] = []
|
||||||
while bigint > 0:
|
while bigint > 0:
|
||||||
bigint, mod = divmod(bigint, 2**32)
|
bigint, mod = divmod(bigint, 2**32)
|
||||||
ints.append(mod)
|
ints.append(mod)
|
||||||
|
@@ -1,7 +1,5 @@
|
|||||||
"""Module for vector environments."""
|
"""Module for vector environments."""
|
||||||
from __future__ import annotations
|
from typing import Iterable, List, Optional, Union
|
||||||
|
|
||||||
from typing import Iterable, Optional, Union
|
|
||||||
|
|
||||||
from gym.vector.async_vector_env import AsyncVectorEnv
|
from gym.vector.async_vector_env import AsyncVectorEnv
|
||||||
from gym.vector.sync_vector_env import SyncVectorEnv
|
from gym.vector.sync_vector_env import SyncVectorEnv
|
||||||
@@ -14,7 +12,7 @@ def make(
|
|||||||
id: str,
|
id: str,
|
||||||
num_envs: int = 1,
|
num_envs: int = 1,
|
||||||
asynchronous: bool = True,
|
asynchronous: bool = True,
|
||||||
wrappers: Optional[Union[callable, list[callable]]] = None,
|
wrappers: Optional[Union[callable, List[callable]]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> VectorEnv:
|
) -> VectorEnv:
|
||||||
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
||||||
|
@@ -1,12 +1,10 @@
|
|||||||
"""An async vector environment."""
|
"""An async vector environment."""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Sequence, Union
|
from typing import List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -199,7 +197,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def reset_async(
|
def reset_async(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, list[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
@@ -250,7 +248,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
) -> Union[ObsType, tuple[ObsType, list[dict]]]:
|
) -> Union[ObsType, Tuple[ObsType, List[dict]]]:
|
||||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -333,7 +331,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def step_wait(
|
def step_wait(
|
||||||
self, timeout: Optional[Union[int, float]] = None
|
self, timeout: Optional[Union[int, float]] = None
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[dict]]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
|
||||||
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -1,8 +1,6 @@
|
|||||||
"""A synchronous vector environment."""
|
"""A synchronous vector environment."""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Iterator, Optional, Sequence, Union
|
from typing import Any, Iterator, List, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -89,7 +87,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def reset_wait(
|
def reset_wait(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, list[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
|
@@ -1,7 +1,5 @@
|
|||||||
"""Base class for vectorized environments."""
|
"""Base class for vectorized environments."""
|
||||||
from __future__ import annotations
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
from typing import Any, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -50,7 +48,7 @@ class VectorEnv(gym.Env):
|
|||||||
|
|
||||||
def reset_async(
|
def reset_async(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, list[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
@@ -68,7 +66,7 @@ class VectorEnv(gym.Env):
|
|||||||
|
|
||||||
def reset_wait(
|
def reset_wait(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, list[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
@@ -91,7 +89,7 @@ class VectorEnv(gym.Env):
|
|||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[Union[int, list[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
@@ -144,10 +142,10 @@ class VectorEnv(gym.Env):
|
|||||||
def call_async(self, name, *args, **kwargs):
|
def call_async(self, name, *args, **kwargs):
|
||||||
"""Calls a method name for each parallel environment asynchronously."""
|
"""Calls a method name for each parallel environment asynchronously."""
|
||||||
|
|
||||||
def call_wait(self, **kwargs) -> list[Any]:
|
def call_wait(self, **kwargs) -> List[Any]:
|
||||||
"""After calling a method in :meth:`call_async`, this function collects the results."""
|
"""After calling a method in :meth:`call_async`, this function collects the results."""
|
||||||
|
|
||||||
def call(self, name: str, *args, **kwargs) -> list[Any]:
|
def call(self, name: str, *args, **kwargs) -> List[Any]:
|
||||||
"""Call a method, or get a property, from each parallel environment.
|
"""Call a method, or get a property, from each parallel environment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -1,6 +1,4 @@
|
|||||||
"""A wrapper for video recording environments by rolling it out, frame by frame."""
|
"""A wrapper for video recording environments by rolling it out, frame by frame."""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
@@ -9,7 +7,7 @@ import shutil
|
|||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Optional, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -362,7 +360,7 @@ class ImageEncoder:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
output_path: str,
|
output_path: str,
|
||||||
frame_shape: tuple[int, int, int],
|
frame_shape: Tuple[int, int, int],
|
||||||
frames_per_sec: int,
|
frames_per_sec: int,
|
||||||
output_frames_per_sec: int,
|
output_frames_per_sec: int,
|
||||||
):
|
):
|
||||||
|
@@ -1,10 +1,8 @@
|
|||||||
"""Wrapper for augmenting observations by pixel values."""
|
"""Wrapper for augmenting observations by pixel values."""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import copy
|
import copy
|
||||||
from collections.abc import MutableMapping
|
from collections.abc import MutableMapping
|
||||||
from typing import Any, Optional
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -52,8 +50,8 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
|||||||
self,
|
self,
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
pixels_only: bool = True,
|
pixels_only: bool = True,
|
||||||
render_kwargs: Optional[dict[str, dict[str, Any]]] = None,
|
render_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||||
pixel_keys: tuple[str, ...] = ("pixels",),
|
pixel_keys: Tuple[str, ...] = ("pixels",),
|
||||||
):
|
):
|
||||||
"""Initializes a new pixel Wrapper.
|
"""Initializes a new pixel Wrapper.
|
||||||
|
|
||||||
|
@@ -14,6 +14,7 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin
|
|||||||
COPY . /usr/local/gym/
|
COPY . /usr/local/gym/
|
||||||
WORKDIR /usr/local/gym/
|
WORKDIR /usr/local/gym/
|
||||||
|
|
||||||
RUN pip install .[noatari] && pip install -r test_requirements.txt
|
RUN if [ python:$PYTHON_VERSION = "python:3.6.15" ] ; then pip install .[box2d,classic_control,toy_text,other] ; else pip install .[noatari] ; fi
|
||||||
|
RUN pip install -r test_requirements.txt
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]
|
ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]
|
||||||
|
@@ -25,7 +25,7 @@ strict = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
typeCheckingMode = "basic"
|
typeCheckingMode = "basic"
|
||||||
pythonVersion = "3.7"
|
pythonVersion = "3.6"
|
||||||
typeshedPath = "typeshed"
|
typeshedPath = "typeshed"
|
||||||
enableTypeIgnoreComments = true
|
enableTypeIgnoreComments = true
|
||||||
|
|
||||||
|
6
setup.py
6
setup.py
@@ -55,8 +55,9 @@ setup(
|
|||||||
install_requires=[
|
install_requires=[
|
||||||
"numpy>=1.18.0",
|
"numpy>=1.18.0",
|
||||||
"cloudpickle>=1.2.0",
|
"cloudpickle>=1.2.0",
|
||||||
"importlib_metadata>=4.10.0; python_version < '3.10'",
|
"importlib_metadata>=4.8.0; python_version < '3.10'",
|
||||||
"gym_notices>=0.0.4",
|
"gym_notices>=0.0.4",
|
||||||
|
"dataclasses==0.8; python_version == '3.6'",
|
||||||
],
|
],
|
||||||
extras_require=extras,
|
extras_require=extras,
|
||||||
package_data={
|
package_data={
|
||||||
@@ -69,9 +70,10 @@ setup(
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
tests_require=["pytest", "mock"],
|
tests_require=["pytest", "mock"],
|
||||||
python_requires=">=3.7",
|
python_requires=">=3.6",
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.6",
|
||||||
"Programming Language :: Python :: 3.7",
|
"Programming Language :: Python :: 3.7",
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
|
@@ -1,3 +1,2 @@
|
|||||||
lz4~=3.1
|
lz4~=3.1
|
||||||
pytest~=6.2
|
pytest~=6.2
|
||||||
pytest-forked~=1.3
|
|
||||||
|
@@ -1,7 +1,10 @@
|
|||||||
from gym import envs, logger
|
from gym import envs, logger
|
||||||
|
|
||||||
SKIP_MUJOCO_V3_WARNING_MESSAGE = (
|
SKIP_MUJOCO_V3_WARNING_MESSAGE = (
|
||||||
"Cannot run mujoco test because mujoco-py is not installed"
|
"Cannot run mujoco test because `mujoco-py` is not installed"
|
||||||
|
)
|
||||||
|
SKIP_MUJOCO_V4_WARNING_MESSAGE = (
|
||||||
|
"Cannot run mujoco test because `mujoco` is not installed"
|
||||||
)
|
)
|
||||||
|
|
||||||
skip_mujoco_v3 = False
|
skip_mujoco_v3 = False
|
||||||
@@ -10,13 +13,19 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
skip_mujoco_v3 = True
|
skip_mujoco_v3 = True
|
||||||
|
|
||||||
|
skip_mujoco_v4 = False
|
||||||
|
try:
|
||||||
|
import mujoco # noqa:F401
|
||||||
|
except ImportError:
|
||||||
|
skip_mujoco_v4 = True
|
||||||
|
|
||||||
|
|
||||||
def should_skip_env_spec_for_tests(spec):
|
def should_skip_env_spec_for_tests(spec):
|
||||||
# We skip tests for envs that require dependencies or are otherwise
|
# We skip tests for envs that require dependencies or are otherwise
|
||||||
# troublesome to run frequently
|
# troublesome to run frequently
|
||||||
ep = spec.entry_point
|
ep = spec.entry_point
|
||||||
# Skip mujoco tests for pull request CI
|
# Skip mujoco tests for pull request CI
|
||||||
if skip_mujoco_v3 and ep.startswith("gym.envs.mujoco"):
|
if (skip_mujoco_v3 or skip_mujoco_v4) and ep.startswith("gym.envs.mujoco"):
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
import gym.envs.atari # noqa:F401
|
import gym.envs.atari # noqa:F401
|
||||||
|
@@ -3,11 +3,11 @@ from typing import List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gym import envs
|
import gym
|
||||||
|
from gym import Env
|
||||||
from gym.envs.registration import EnvSpec
|
from gym.envs.registration import EnvSpec
|
||||||
from gym.spaces.box import Box
|
from gym.spaces.box import Box
|
||||||
from gym.spaces.discrete import Discrete
|
from gym.spaces.discrete import Discrete
|
||||||
from gym.spaces.space import Space
|
|
||||||
from tests.envs.spec_list import (
|
from tests.envs.spec_list import (
|
||||||
SKIP_MUJOCO_V3_WARNING_MESSAGE,
|
SKIP_MUJOCO_V3_WARNING_MESSAGE,
|
||||||
skip_mujoco_v3,
|
skip_mujoco_v3,
|
||||||
@@ -17,17 +17,20 @@ from tests.envs.spec_list import (
|
|||||||
ENVIRONMENT_IDS = ("HalfCheetah-v2",)
|
ENVIRONMENT_IDS = ("HalfCheetah-v2",)
|
||||||
|
|
||||||
|
|
||||||
def make_envs_by_action_space_type(spec_list: List[EnvSpec], action_space: Space):
|
def filters_envs_action_space_type(
|
||||||
|
env_spec_list: List[EnvSpec], action_space: type
|
||||||
|
) -> List[Env]:
|
||||||
"""Make environments of specific action_space type.
|
"""Make environments of specific action_space type.
|
||||||
This function returns a filtered list of environment from the
|
|
||||||
spec_list that matches the action_space type.
|
This function returns a filtered list of environment from the spec_list that matches the action_space type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
spec_list (list): list of registered environments' specification
|
env_spec_list (list): list of registered environments' specification
|
||||||
action_space (gym.spaces.Space): action_space type
|
action_space (gym.spaces.Space): action_space type
|
||||||
"""
|
"""
|
||||||
filtered_envs = []
|
filtered_envs = []
|
||||||
for spec in spec_list:
|
for spec in env_spec_list:
|
||||||
env = envs.make(spec.id)
|
env = gym.make(spec.id)
|
||||||
if isinstance(env.action_space, action_space):
|
if isinstance(env.action_space, action_space):
|
||||||
filtered_envs.append(env)
|
filtered_envs.append(env)
|
||||||
return filtered_envs
|
return filtered_envs
|
||||||
@@ -36,7 +39,7 @@ def make_envs_by_action_space_type(spec_list: List[EnvSpec], action_space: Space
|
|||||||
@pytest.mark.skipif(skip_mujoco_v3, reason=SKIP_MUJOCO_V3_WARNING_MESSAGE)
|
@pytest.mark.skipif(skip_mujoco_v3, reason=SKIP_MUJOCO_V3_WARNING_MESSAGE)
|
||||||
@pytest.mark.parametrize("environment_id", ENVIRONMENT_IDS)
|
@pytest.mark.parametrize("environment_id", ENVIRONMENT_IDS)
|
||||||
def test_serialize_deserialize(environment_id):
|
def test_serialize_deserialize(environment_id):
|
||||||
env = envs.make(environment_id)
|
env = gym.make(environment_id)
|
||||||
env.reset()
|
env.reset()
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Action dimension mismatch"):
|
with pytest.raises(ValueError, match="Action dimension mismatch"):
|
||||||
@@ -46,7 +49,7 @@ def test_serialize_deserialize(environment_id):
|
|||||||
env.step(0.1)
|
env.step(0.1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env", make_envs_by_action_space_type(spec_list, Discrete))
|
@pytest.mark.parametrize("env", filters_envs_action_space_type(spec_list, Discrete))
|
||||||
def test_discrete_actions_out_of_bound(env):
|
def test_discrete_actions_out_of_bound(env):
|
||||||
"""Test out of bound actions in Discrete action_space.
|
"""Test out of bound actions in Discrete action_space.
|
||||||
In discrete action_space environments, `out-of-bound`
|
In discrete action_space environments, `out-of-bound`
|
||||||
@@ -65,7 +68,7 @@ def test_discrete_actions_out_of_bound(env):
|
|||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("env", "seed"),
|
("env", "seed"),
|
||||||
[(env, 42) for env in make_envs_by_action_space_type(spec_list, Box)],
|
[(env, 42) for env in filters_envs_action_space_type(spec_list, Box)],
|
||||||
)
|
)
|
||||||
def test_box_actions_out_of_bound(env, seed):
|
def test_box_actions_out_of_bound(env, seed):
|
||||||
"""Test out of bound actions in Box action_space.
|
"""Test out of bound actions in Box action_space.
|
||||||
@@ -80,7 +83,7 @@ def test_box_actions_out_of_bound(env, seed):
|
|||||||
|
|
||||||
env.reset(seed=seed)
|
env.reset(seed=seed)
|
||||||
|
|
||||||
oob_env = envs.make(env.spec.id)
|
oob_env = gym.make(env.spec.id)
|
||||||
oob_env.reset(seed=seed)
|
oob_env.reset(seed=seed)
|
||||||
|
|
||||||
dtype = env.action_space.dtype
|
dtype = env.action_space.dtype
|
||||||
|
Reference in New Issue
Block a user