Rename to gymnasium

This commit is contained in:
pseudo-rnd-thoughts
2022-09-08 10:11:31 +01:00
parent 640c509109
commit 47ba48b611
25 changed files with 121 additions and 34 deletions

View File

@@ -21,7 +21,9 @@ try:
revoluteJointDef, revoluteJointDef,
) )
except ImportError: except ImportError:
raise DependencyNotInstalled("box2D is not installed, run `pip install gymnasium[box2d]`") raise DependencyNotInstalled(
"box2D is not installed, run `pip install gymnasium[box2d]`"
)
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@@ -17,7 +17,9 @@ from gymnasium.error import DependencyNotInstalled
try: try:
from Box2D.b2 import fixtureDef, polygonShape, revoluteJointDef from Box2D.b2 import fixtureDef, polygonShape, revoluteJointDef
except ImportError: except ImportError:
raise DependencyNotInstalled("box2D is not installed, run `pip install gymnasium[box2d]`") raise DependencyNotInstalled(
"box2D is not installed, run `pip install gymnasium[box2d]`"
)
SIZE = 0.02 SIZE = 0.02

View File

@@ -15,7 +15,9 @@ try:
import Box2D import Box2D
from Box2D.b2 import contactListener, fixtureDef, polygonShape from Box2D.b2 import contactListener, fixtureDef, polygonShape
except ImportError: except ImportError:
raise DependencyNotInstalled("box2D is not installed, run `pip install gymnasium[box2d]`") raise DependencyNotInstalled(
"box2D is not installed, run `pip install gymnasium[box2d]`"
)
try: try:
# As pygame is necessary for using the environment (reset and step) even without a render mode # As pygame is necessary for using the environment (reset and step) even without a render mode

View File

@@ -23,7 +23,9 @@ try:
revoluteJointDef, revoluteJointDef,
) )
except ImportError: except ImportError:
raise DependencyNotInstalled("box2d is not installed, run `pip install gymnasium[box2d]`") raise DependencyNotInstalled(
"box2d is not installed, run `pip install gymnasium[box2d]`"
)
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@@ -1,5 +1,7 @@
from gymnasium.envs.classic_control.acrobot import AcrobotEnv from gymnasium.envs.classic_control.acrobot import AcrobotEnv
from gymnasium.envs.classic_control.cartpole import CartPoleEnv from gymnasium.envs.classic_control.cartpole import CartPoleEnv
from gymnasium.envs.classic_control.continuous_mountain_car import Continuous_MountainCarEnv from gymnasium.envs.classic_control.continuous_mountain_car import (
Continuous_MountainCarEnv,
)
from gymnasium.envs.classic_control.mountain_car import MountainCarEnv from gymnasium.envs.classic_control.mountain_car import MountainCarEnv
from gymnasium.envs.classic_control.pendulum import PendulumEnv from gymnasium.envs.classic_control.pendulum import PendulumEnv

View File

@@ -395,7 +395,9 @@ class MujocoEnv(BaseMujocoEnv):
def _get_viewer( def _get_viewer(
self, mode self, mode
) -> Union["gymnasium.envs.mujoco.Viewer", "gymnasium.envs.mujoco.RenderContextOffscreen"]: ) -> Union[
"gymnasium.envs.mujoco.Viewer", "gymnasium.envs.mujoco.RenderContextOffscreen"
]:
self.viewer = self._viewers.get(mode) self.viewer = self._viewers.get(mode)
if self.viewer is None: if self.viewer is None:
if mode == "human": if mode == "human":

View File

@@ -6,8 +6,12 @@ from gymnasium.vector.utils.shared_memory import (
read_from_shared_memory, read_from_shared_memory,
write_to_shared_memory, write_to_shared_memory,
) )
from gymnasium.vector.utils.spaces import _BaseGymSpaces # pyright: reportPrivateUsage=false from gymnasium.vector.utils.spaces import ( # pyright: reportPrivateUsage=false
from gymnasium.vector.utils.spaces import BaseGymSpaces, batch_space, iterate BaseGymSpaces,
_BaseGymSpaces,
batch_space,
iterate,
)
__all__ = [ __all__ = [
"CloudpickleWrapper", "CloudpickleWrapper",

View File

@@ -5,7 +5,15 @@ from typing import Iterable, Union
import numpy as np import numpy as np
from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
__all__ = ["concatenate", "create_empty_array"] __all__ = ["concatenate", "create_empty_array"]

View File

@@ -8,7 +8,15 @@ from typing import Union
import numpy as np import numpy as np
from gymnasium.error import CustomSpaceError from gymnasium.error import CustomSpaceError
from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"] __all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]

View File

@@ -7,7 +7,15 @@ from typing import Iterator
import numpy as np import numpy as np
from gymnasium.error import CustomSpaceError from gymnasium.error import CustomSpaceError
from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary) BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
_BaseGymSpaces = BaseGymSpaces _BaseGymSpaces = BaseGymSpaces

View File

@@ -4,7 +4,9 @@ from typing import Any, Dict, Optional, Tuple
import gymnasium import gymnasium
from gymnasium.core import ObsType from gymnasium.core import ObsType
from gymnasium.utils.step_api_compatibility import convert_to_terminated_truncated_step_api from gymnasium.utils.step_api_compatibility import (
convert_to_terminated_truncated_step_api,
)
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
from typing import Protocol, runtime_checkable from typing import Protocol, runtime_checkable

View File

@@ -19,7 +19,9 @@ class OrderEnforcing(gymnasium.Wrapper):
>>> env.step(0) >>> env.step(0)
""" """
def __init__(self, env: gymnasium.Env, disable_render_order_enforcing: bool = False): def __init__(
self, env: gymnasium.Env, disable_render_order_enforcing: bool = False
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Args: Args:

View File

@@ -5,7 +5,9 @@ import gymnasium
class RenderCollection(gymnasium.Wrapper): class RenderCollection(gymnasium.Wrapper):
"""Save collection of render frames.""" """Save collection of render frames."""
def __init__(self, env: gymnasium.Env, pop_frames: bool = True, reset_clean: bool = True): def __init__(
self, env: gymnasium.Env, pop_frames: bool = True, reset_clean: bool = True
):
"""Initialize a :class:`RenderCollection` instance. """Initialize a :class:`RenderCollection` instance.
Args: Args:

View File

@@ -72,7 +72,9 @@ setup(
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
name="gymnasium", name="gymnasium",
packages=[package for package in find_packages() if package.startswith("gymnasium")], packages=[
package for package in find_packages() if package.startswith("gymnasium")
],
package_data={ package_data={
"gymnasium": [ "gymnasium": [
"envs/mujoco/assets/*.xml", "envs/mujoco/assets/*.xml",

View File

@@ -72,7 +72,9 @@ def test_bipedal_walker_hardcore_creation(seed: int):
HC_TERRAINS_COLOR2 = (153, 153, 153) HC_TERRAINS_COLOR2 = (153, 153, 153)
env = gymnasium.make("BipedalWalker-v3", disable_env_checker=True).unwrapped env = gymnasium.make("BipedalWalker-v3", disable_env_checker=True).unwrapped
hc_env = gymnasium.make("BipedalWalkerHardcore-v3", disable_env_checker=True).unwrapped hc_env = gymnasium.make(
"BipedalWalkerHardcore-v3", disable_env_checker=True
).unwrapped
assert isinstance(env, BipedalWalker) and isinstance(hc_env, BipedalWalker) assert isinstance(env, BipedalWalker) and isinstance(hc_env, BipedalWalker)
assert env.hardcore is False and hc_env.hardcore is True assert env.hardcore is False and hc_env.hardcore is True

View File

@@ -9,7 +9,12 @@ import pytest
import gymnasium import gymnasium
from gymnasium.envs.classic_control import cartpole from gymnasium.envs.classic_control import cartpole
from gymnasium.wrappers import AutoResetWrapper, HumanRendering, OrderEnforcing, TimeLimit from gymnasium.wrappers import (
AutoResetWrapper,
HumanRendering,
OrderEnforcing,
TimeLimit,
)
from gymnasium.wrappers.env_checker import PassiveEnvChecker from gymnasium.wrappers.env_checker import PassiveEnvChecker
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
from tests.envs.utils import all_testing_env_specs from tests.envs.utils import all_testing_env_specs
@@ -69,7 +74,8 @@ def test_make_max_episode_steps():
env = gymnasium.make("CartPole-v1", disable_env_checker=True) env = gymnasium.make("CartPole-v1", disable_env_checker=True)
assert has_wrapper(env, TimeLimit) assert has_wrapper(env, TimeLimit)
assert ( assert (
env.spec.max_episode_steps == gymnasium.envs.registry["CartPole-v1"].max_episode_steps env.spec.max_episode_steps
== gymnasium.envs.registry["CartPole-v1"].max_episode_steps
) )
env.close() env.close()

View File

@@ -86,7 +86,9 @@ def test_register(
], ],
) )
def test_register_error(env_id): def test_register_error(env_id):
with pytest.raises(gymnasium.error.Error, match=f"^Malformed environment ID: {env_id}"): with pytest.raises(
gymnasium.error.Error, match=f"^Malformed environment ID: {env_id}"
):
gymnasium.register(env_id, "no-entry-point") gymnasium.register(env_id, "no-entry-point")

View File

@@ -89,19 +89,25 @@ def test_bad_space_calls(space_fn):
def test_contains_promotion(): def test_contains_promotion():
space = gymnasium.spaces.Tuple((gymnasium.spaces.Box(0, 1), gymnasium.spaces.Box(-1, 0, (2,)))) space = gymnasium.spaces.Tuple(
(gymnasium.spaces.Box(0, 1), gymnasium.spaces.Box(-1, 0, (2,)))
)
assert ( assert (
np.array([0.0], dtype=np.float32), np.array([0.0], dtype=np.float32),
np.array([0.0, 0.0], dtype=np.float32), np.array([0.0, 0.0], dtype=np.float32),
) in space ) in space
space = gymnasium.spaces.Tuple((gymnasium.spaces.Box(0, 1), gymnasium.spaces.Box(-1, 0, (1,)))) space = gymnasium.spaces.Tuple(
(gymnasium.spaces.Box(0, 1), gymnasium.spaces.Box(-1, 0, (1,)))
)
assert np.array([[0.0], [0.0]], dtype=np.float32) in space assert np.array([[0.0], [0.0]], dtype=np.float32) in space
def test_bad_seed(): def test_bad_seed():
space = gymnasium.spaces.Tuple((gymnasium.spaces.Box(0, 1), gymnasium.spaces.Box(0, 1))) space = gymnasium.spaces.Tuple(
(gymnasium.spaces.Box(0, 1), gymnasium.spaces.Box(0, 1))
)
with pytest.raises( with pytest.raises(
TypeError, TypeError,
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>", match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",

View File

@@ -92,7 +92,8 @@ def test_flatten_space(space):
assert edge_single_dim == edge_flatdim assert edge_single_dim == edge_flatdim
else: else:
assert isinstance( assert isinstance(
space, (gymnasium.spaces.Tuple, gymnasium.spaces.Dict, gymnasium.spaces.Sequence) space,
(gymnasium.spaces.Tuple, gymnasium.spaces.Dict, gymnasium.spaces.Sequence),
) )

View File

@@ -188,7 +188,9 @@ def test_play_loop_real_env():
action = keys_to_action[chr(e.key) if str_keys else (e.key,)] action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
obs, _, _, _, _ = env.step(action) obs, _, _, _, _ = env.step(action)
env_play = gymnasium.make(ENV, render_mode="rgb_array", disable_env_checker=True) env_play = gymnasium.make(
ENV, render_mode="rgb_array", disable_env_checker=True
)
if apply_wrapper: if apply_wrapper:
env_play = KeysToActionWrapper(env, keys_to_action=keys_to_action) env_play = KeysToActionWrapper(env, keys_to_action=keys_to_action)
assert hasattr(env_play, "get_keys_to_action") assert hasattr(env_play, "get_keys_to_action")

View File

@@ -4,7 +4,11 @@ from multiprocessing import TimeoutError
import numpy as np import numpy as np
import pytest import pytest
from gymnasium.error import AlreadyPendingCallError, ClosedEnvironmentError, NoAsyncCallError from gymnasium.error import (
AlreadyPendingCallError,
ClosedEnvironmentError,
NoAsyncCallError,
)
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple
from gymnasium.vector.async_vector_env import AsyncVectorEnv from gymnasium.vector.async_vector_env import AsyncVectorEnv
from tests.vector.utils import ( from tests.vector.utils import (

View File

@@ -26,8 +26,12 @@ def test_record_video_using_default_trigger():
def test_record_video_reset(): def test_record_video_reset():
env = gymnasium.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True) env = gymnasium.make(
env = gymnasium.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) "CartPole-v1", render_mode="rgb_array", disable_env_checker=True
)
env = gymnasium.wrappers.RecordVideo(
env, "videos", step_trigger=lambda x: x % 100 == 0
)
ob_space = env.observation_space ob_space = env.observation_space
obs, info = env.reset() obs, info = env.reset()
env.close() env.close()
@@ -38,9 +42,13 @@ def test_record_video_reset():
def test_record_video_step_trigger(): def test_record_video_step_trigger():
env = gymnasium.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True) env = gymnasium.make(
"CartPole-v1", render_mode="rgb_array", disable_env_checker=True
)
env._max_episode_steps = 20 env._max_episode_steps = 20
env = gymnasium.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env = gymnasium.wrappers.RecordVideo(
env, "videos", step_trigger=lambda x: x % 100 == 0
)
env.reset() env.reset()
for _ in range(199): for _ in range(199):
action = env.action_space.sample() action = env.action_space.sample()

View File

@@ -12,7 +12,8 @@ def test_transform_observation(env_id):
env = gymnasium.make(env_id, disable_env_checker=True) env = gymnasium.make(env_id, disable_env_checker=True)
wrapped_env = TransformObservation( wrapped_env = TransformObservation(
gymnasium.make(env_id, disable_env_checker=True), lambda obs: affine_transform(obs) gymnasium.make(env_id, disable_env_checker=True),
lambda obs: affine_transform(obs),
) )
obs, info = env.reset(seed=0) obs, info = env.reset(seed=0)

View File

@@ -30,7 +30,8 @@ def test_transform_reward(env_id):
max_r = 0.0002 max_r = 0.0002
env = gymnasium.make(env_id, disable_env_checker=True) env = gymnasium.make(env_id, disable_env_checker=True)
wrapped_env = TransformReward( wrapped_env = TransformReward(
gymnasium.make(env_id, disable_env_checker=True), lambda r: np.clip(r, min_r, max_r) gymnasium.make(env_id, disable_env_checker=True),
lambda r: np.clip(r, min_r, max_r),
) )
action = env.action_space.sample() action = env.action_space.sample()

View File

@@ -11,7 +11,9 @@ SEED = 42
def test_usage_in_vector_env(): def test_usage_in_vector_env():
env = gymnasium.make(ENV_ID, disable_env_checker=True) env = gymnasium.make(ENV_ID, disable_env_checker=True)
vector_env = gymnasium.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True) vector_env = gymnasium.vector.make(
ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True
)
VectorListInfo(vector_env) VectorListInfo(vector_env)
@@ -20,7 +22,9 @@ def test_usage_in_vector_env():
def test_info_to_list(): def test_info_to_list():
env_to_wrap = gymnasium.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True) env_to_wrap = gymnasium.vector.make(
ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True
)
wrapped_env = VectorListInfo(env_to_wrap) wrapped_env = VectorListInfo(env_to_wrap)
wrapped_env.action_space.seed(SEED) wrapped_env.action_space.seed(SEED)
_, info = wrapped_env.reset(seed=SEED) _, info = wrapped_env.reset(seed=SEED)
@@ -38,7 +42,9 @@ def test_info_to_list():
def test_info_to_list_statistics(): def test_info_to_list_statistics():
env_to_wrap = gymnasium.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True) env_to_wrap = gymnasium.vector.make(
ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True
)
wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap)) wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
_, info = wrapped_env.reset(seed=SEED) _, info = wrapped_env.reset(seed=SEED)
wrapped_env.action_space.seed(SEED) wrapped_env.action_space.seed(SEED)