From 8b81b7dcc2ec3a4d29526083b14657c1a838b23e Mon Sep 17 00:00:00 2001
From: Ariel Kwiatkowski
Date: Thu, 20 Oct 2022 11:30:14 +0200
Subject: [PATCH] Gym-Gymnasium compatibility converter (#61)

---
 gymnasium/envs/__init__.py          |   8 ++
 gymnasium/envs/external/__init__.py |   0
 gymnasium/envs/external/gym_env.py  | 168 ++++++++++++++++++++++++++++
 gymnasium/error.py                  |   4 +
 setup.py                            |   2 +-
 test_requirements.txt               |   1 +
 tests/envs/test_compatibility.py    |  18 ++--
 tests/envs/test_gym_conversion.py   |  29 +++++
 tests/envs/utils.py                 |   6 +-
 9 files changed, 225 insertions(+), 11 deletions(-)
 create mode 100644 gymnasium/envs/external/__init__.py
 create mode 100644 gymnasium/envs/external/gym_env.py
 create mode 100644 tests/envs/test_gym_conversion.py

diff --git a/gymnasium/envs/__init__.py b/gymnasium/envs/__init__.py
index a7b408718..64392b95b 100644
--- a/gymnasium/envs/__init__.py
+++ b/gymnasium/envs/__init__.py
@@ -318,3 +318,11 @@ register(
     entry_point="gymnasium.envs.mujoco.humanoidstandup_v4:HumanoidStandupEnv",
     max_episode_steps=1000,
 )
+
+
+# Gym conversion
+# ----------------------------------------
+register(
+    id="GymV26Environment-v0",
+    entry_point="gymnasium.envs.external.gym_env:GymEnvironment",
+)
diff --git a/gymnasium/envs/external/__init__.py b/gymnasium/envs/external/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/gymnasium/envs/external/gym_env.py b/gymnasium/envs/external/gym_env.py
new file mode 100644
index 000000000..db6494320
--- /dev/null
+++ b/gymnasium/envs/external/gym_env.py
@@ -0,0 +1,168 @@
+from typing import Optional, Tuple
+
+import gymnasium
+from gymnasium import error
+from gymnasium.core import ActType, ObsType
+
+# gym is optional: remember an import failure here and report it in GymEnvironment.__init__
+try:
+    import gym
+    import gym.wrappers
+except ImportError as e:
+    GYM_IMPORT_ERROR = e
+else:
+    GYM_IMPORT_ERROR = None
+
+
+class GymEnvironment(gymnasium.Env):
+    """Converts a gym environment to a gymnasium environment.
+
+    Args:
+        env_id: id passed to ``gym.make`` to create the environment; only used if ``env`` is None
+        make_kwargs: keyword arguments passed to ``gym.make`` along with ``env_id``
+        env: an already instantiated gym environment; takes precedence over ``env_id``
+
+    Raises:
+        DependencyNotInstalled: if gym could not be imported
+        MissingArgument: if neither ``env_id`` nor ``env`` is provided
+    """
+
+    def __init__(
+        self,
+        env_id: Optional[str] = None,
+        make_kwargs: Optional[dict] = None,
+        env: Optional["gym.Env"] = None,
+    ):
+        if GYM_IMPORT_ERROR is not None:
+            raise error.DependencyNotInstalled(
+                f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments)"
+            )
+
+        if make_kwargs is None:
+            make_kwargs = {}
+
+        if env is not None:
+            self.gym_env = env
+        elif env_id is not None:
+            self.gym_env = gym.make(env_id, **make_kwargs)
+        else:
+            raise error.MissingArgument(
+                "Either env_id or env must be provided to create a legacy gym environment."
+            )
+        self.gym_env = _strip_default_wrappers(self.gym_env)
+
+        self.observation_space = _convert_space(self.gym_env.observation_space)
+        self.action_space = _convert_space(self.gym_env.action_space)
+
+        self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []})
+        self.render_mode = getattr(self.gym_env, "render_mode", None)
+        self.reward_range = getattr(self.gym_env, "reward_range", None)
+        self.spec = getattr(self.gym_env, "spec", None)
+
+    def reset(
+        self, seed: Optional[int] = None, options: Optional[dict] = None
+    ) -> Tuple[ObsType, dict]:
+        """Resets the environment.
+
+        Args:
+            seed: the seed to reset the environment with
+            options: the options to reset the environment with
+
+        Returns:
+            (observation, info)
+        """
+        super().reset(seed=seed)
+        # Both seed and options are forwarded unchanged to the wrapped gym environment
+        return self.gym_env.reset(seed=seed, options=options)
+
+    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
+        """Steps through the environment.
+
+        Args:
+            action: action to step through the environment with
+
+        Returns:
+            (observation, reward, terminated, truncated, info)
+        """
+        return self.gym_env.step(action)
+
+    def render(self):
+        """Renders the environment.
+
+        Returns:
+            The rendering of the environment, depending on the render mode
+        """
+        return self.gym_env.render()
+
+    def close(self):
+        """Closes the environment."""
+        self.gym_env.close()
+
+    def __str__(self):
+        return f"GymEnvironment({self.gym_env})"
+
+    def __repr__(self):
+        return f"GymEnvironment({self.gym_env})"
+
+
+def _strip_default_wrappers(env: "gym.Env") -> "gym.Env":
+    """Strips builtin wrappers from the environment.
+
+    Args:
+        env: the environment to strip builtin wrappers from
+
+    Returns:
+        The environment without builtin wrappers
+    """
+
+    default_wrappers = (
+        gym.wrappers.render_collection.RenderCollection,
+        gym.wrappers.human_rendering.HumanRendering,
+    )
+    while isinstance(env, default_wrappers):
+        env = env.env
+    return env
+
+
+def _convert_space(space: "gym.Space") -> gymnasium.Space:
+    """Converts a gym space to a gymnasium space.
+
+    Args:
+        space: the space to convert
+
+    Returns:
+        The converted space
+    """
+    if isinstance(space, gym.spaces.Discrete):
+        return gymnasium.spaces.Discrete(n=space.n)
+    elif isinstance(space, gym.spaces.Box):
+        return gymnasium.spaces.Box(
+            low=space.low, high=space.high, shape=space.shape, dtype=space.dtype
+        )
+    elif isinstance(space, gym.spaces.MultiDiscrete):
+        return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
+    elif isinstance(space, gym.spaces.MultiBinary):
+        return gymnasium.spaces.MultiBinary(n=space.n)
+    elif isinstance(space, gym.spaces.Tuple):
+        return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces)))
+    elif isinstance(space, gym.spaces.Dict):
+        return gymnasium.spaces.Dict(
+            spaces={k: _convert_space(v) for k, v in space.spaces.items()}
+        )
+    elif isinstance(space, gym.spaces.Sequence):
+        return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space))
+    elif isinstance(space, gym.spaces.Graph):
+        return gymnasium.spaces.Graph(
+            node_space=_convert_space(space.node_space),  # type: ignore
+            edge_space=_convert_space(space.edge_space),  # type: ignore
+        )
+    elif isinstance(space, gym.spaces.Text):
+        return gymnasium.spaces.Text(
+            max_length=space.max_length,
+            min_length=space.min_length,
+            charset=space._char_str,
+        )
+    else:
+        raise NotImplementedError(
+            f"Cannot convert space of type {space}. Please upgrade your code to gymnasium."
+        )
diff --git a/gymnasium/error.py b/gymnasium/error.py
index e80282ed8..57730b85b 100644
--- a/gymnasium/error.py
+++ b/gymnasium/error.py
@@ -65,6 +65,10 @@ class InvalidAction(Error):
     """Raised when the user performs an action not contained within the action space."""
 
 
+class MissingArgument(Error):
+    """Raised when a required argument in the initializer is missing."""
+
+
 # API errors
 
 
diff --git a/setup.py b/setup.py
index 7ded89c31..9e044ec12 100644
--- a/setup.py
+++ b/setup.py
@@ -47,7 +47,7 @@ extras = {
 testing_group = set(extras.keys()) - {"accept-rom-license", "atari"}
 extras["testing"] = list(
     set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group)))
-) + ["pytest==7.0.1"]
+) + ["pytest==7.0.1", "gym==0.26.2"]
 
 # All dependency groups - accept rom license as requires user to run
 all_groups = set(extras.keys()) - {"accept-rom-license"}
diff --git a/test_requirements.txt b/test_requirements.txt
index 036602735..0a488c30d 100644
--- a/test_requirements.txt
+++ b/test_requirements.txt
@@ -7,3 +7,4 @@ imageio>=2.14.1
 pygame==2.1.0
 mujoco_py<2.2,>=2.1
 pytest==7.0.1
+gym==0.26.2
diff --git a/tests/envs/test_compatibility.py b/tests/envs/test_compatibility.py
index c15e01c37..d651579ac 100644
--- a/tests/envs/test_compatibility.py
+++ b/tests/envs/test_compatibility.py
@@ -3,12 +3,12 @@ from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
 
-import gymnasium as gym
+import gymnasium
 from gymnasium.spaces import Discrete
 from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv
 
 
-class LegacyEnvExplicit(LegacyEnv, gym.Env):
+class LegacyEnvExplicit(LegacyEnv, gymnasium.Env):
     """Legacy env that explicitly implements the old API."""
 
     observation_space = Discrete(1)
@@ -37,7 +37,7 @@ class LegacyEnvExplicit(LegacyEnv, gym.Env):
         pass
 
 
-class LegacyEnvImplicit(gym.Env):
+class LegacyEnvImplicit(gymnasium.Env):
     """Legacy env that implicitly implements the old API as a protocol."""
 
     observation_space = Discrete(1)
@@ -95,12 +95,12 @@ def test_implicit():
 
 
 def test_make_compatibility_in_spec():
-    gym.register(
+    gymnasium.register(
         id="LegacyTestEnv-v0",
         entry_point=LegacyEnvExplicit,
         apply_api_compatibility=True,
     )
-    env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array")
+    env = gymnasium.make("LegacyTestEnv-v0", render_mode="rgb_array")
     assert env.observation_space == Discrete(1)
     assert env.action_space == Discrete(1)
     assert env.reset() == (0, {})
@@ -110,12 +110,12 @@ def test_make_compatibility_in_spec():
     assert isinstance(img, np.ndarray)
     assert img.shape == (1, 1, 3)  # type: ignore
     env.close()
-    del gym.envs.registration.registry["LegacyTestEnv-v0"]
+    del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]
 
 
 def test_make_compatibility_in_make():
-    gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
-    env = gym.make(
+    gymnasium.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
+    env = gymnasium.make(
         "LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
     )
     assert env.observation_space == Discrete(1)
@@ -127,4 +127,4 @@ def test_make_compatibility_in_make():
     assert isinstance(img, np.ndarray)
     assert img.shape == (1, 1, 3)  # type: ignore
     env.close()
-    del gym.envs.registration.registry["LegacyTestEnv-v0"]
+    del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]
diff --git a/tests/envs/test_gym_conversion.py b/tests/envs/test_gym_conversion.py
new file mode 100644
index 000000000..8f97d41cc
--- /dev/null
+++ b/tests/envs/test_gym_conversion.py
@@ -0,0 +1,29 @@
+import pytest
+
+import gymnasium
+from gymnasium.utils.env_checker import check_env
+
+pytest.importorskip("gym")
+
+import gym  # noqa: E402, isort: skip
+
+ALL_GYM_ENVS = gym.envs.registry.keys()
+
+
+@pytest.mark.parametrize(
+    "env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
+)
+def test_gym_conversion_by_id(env_id):
+    env = gymnasium.make("GymV26Environment-v0", env_id=env_id)
+    check_env(env)
+    env.close()
+
+
+@pytest.mark.parametrize(
+    "env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
+)
+def test_gym_conversion_instantiated(env_id):
+    env = gym.make(env_id)
+    env = gymnasium.make("GymV26Environment-v0", env=env)
+    check_env(env)
+    env.close()
diff --git a/tests/envs/utils.py b/tests/envs/utils.py
index 36f60a1a1..254719fc1 100644
--- a/tests/envs/utils.py
+++ b/tests/envs/utils.py
@@ -17,7 +17,11 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
     if "gymnasium.envs." in env_spec.entry_point:
         try:
             return env_spec.make(disable_env_checker=True).unwrapped
-        except (ImportError, gym.error.DependencyNotInstalled) as e:
+        except (
+            ImportError,
+            gym.error.DependencyNotInstalled,
+            gym.error.MissingArgument,
+        ) as e:
             logger.warn(f"Not testing {env_spec.id} due to error: {e}")
     return None
 