Gym-Gymnasium compatibility converter (#61)

This commit is contained in:
Ariel Kwiatkowski
2022-10-20 11:30:14 +02:00
committed by GitHub
parent e7c8a8cb59
commit 8b81b7dcc2
9 changed files with 214 additions and 11 deletions

View File

@@ -318,3 +318,11 @@ register(
entry_point="gymnasium.envs.mujoco.humanoidstandup_v4:HumanoidStandupEnv", entry_point="gymnasium.envs.mujoco.humanoidstandup_v4:HumanoidStandupEnv",
max_episode_steps=1000, max_episode_steps=1000,
) )
# Gym conversion
# ----------------------------------------
# Registers the gym-to-gymnasium converter under its own environment id. The
# entry point names the ``GymEnvironment`` class in
# ``gymnasium/envs/external/gym_env.py``.
# NOTE(review): "V26" presumably refers to the gym 0.26 API - confirm.
register(
    id="GymV26Environment-v0",
    entry_point="gymnasium.envs.external.gym_env:GymEnvironment",
)

0
gymnasium/envs/external/__init__.py vendored Normal file
View File

159
gymnasium/envs/external/gym_env.py vendored Normal file
View File

@@ -0,0 +1,159 @@
from typing import Optional, Tuple
import gymnasium
from gymnasium import error
from gymnasium.core import ActType, ObsType
try:
import gym
import gym.wrappers
except ImportError as e:
GYM_IMPORT_ERROR = e
else:
GYM_IMPORT_ERROR = None
class GymEnvironment(gymnasium.Env):
    """Converts a gym environment to a gymnasium environment.

    The wrapped gym environment is stored in ``self.gym_env``. Its observation
    and action spaces are converted with ``_convert_space``, while ``reset``,
    ``step``, ``render`` and ``close`` are forwarded to it and their results
    returned unchanged.
    """

    def __init__(
        self,
        env_id: Optional[str] = None,
        make_kwargs: Optional[dict] = None,
        env: Optional["gym.Env"] = None,
    ):
        """Wraps an existing gym environment or creates one from its id.

        Args:
            env_id: id passed to ``gym.make`` when ``env`` is not provided
            make_kwargs: keyword arguments forwarded to ``gym.make``
                (defaults to no extra arguments)
            env: an already instantiated gym environment; takes precedence
                over ``env_id`` when both are given

        Raises:
            error.DependencyNotInstalled: if ``gym`` could not be imported
            error.MissingArgument: if neither ``env_id`` nor ``env`` is given
        """
        if GYM_IMPORT_ERROR is not None:
            raise error.DependencyNotInstalled(
                f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments)"
            )

        if make_kwargs is None:
            make_kwargs = {}

        if env is not None:
            self.gym_env = env
        elif env_id is not None:
            self.gym_env = gym.make(env_id, **make_kwargs)
        else:
            raise error.MissingArgument(
                "Either env_id or env must be provided to create a legacy gym environment."
            )

        # Drop gym's builtin rendering wrappers before reading any attribute.
        self.gym_env = _strip_default_wrappers(self.gym_env)

        self.observation_space = _convert_space(self.gym_env.observation_space)
        self.action_space = _convert_space(self.gym_env.action_space)

        # Mirror the wrapped environment's descriptive attributes, falling
        # back to a neutral value when the wrapped object does not define one.
        self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []})
        self.render_mode = getattr(self.gym_env, "render_mode", None)
        self.reward_range = getattr(self.gym_env, "reward_range", None)
        self.spec = getattr(self.gym_env, "spec", None)

    def reset(
        self, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[ObsType, dict]:
        """Resets the environment.

        Args:
            seed: the seed to reset the environment with
            options: the options to reset the environment with

        Returns:
            (observation, info)
        """
        super().reset(seed=seed)
        # Both the seed and the options are forwarded to the wrapped environment.
        return self.gym_env.reset(seed=seed, options=options)

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
        """Steps through the environment.

        Args:
            action: action to step through the environment with

        Returns:
            (observation, reward, terminated, truncated, info)
        """
        return self.gym_env.step(action)

    def render(self):
        """Renders the environment.

        Returns:
            The rendering of the environment, depending on the render mode
        """
        return self.gym_env.render()

    def close(self):
        """Closes the environment."""
        self.gym_env.close()

    def __str__(self):
        """Returns the wrapped environment's string form inside ``GymEnvironment(...)``."""
        return f"GymEnvironment({self.gym_env})"

    def __repr__(self):
        """Returns the same text as ``__str__``."""
        return f"GymEnvironment({self.gym_env})"
def _strip_default_wrappers(env: "gym.Env") -> "gym.Env":
    """Peels gym's builtin rendering wrappers off an environment.

    Args:
        env: the environment to unwrap

    Returns:
        The first environment in the wrapper chain that is neither a
        ``RenderCollection`` nor a ``HumanRendering`` wrapper
    """
    builtin_wrappers = (
        gym.wrappers.render_collection.RenderCollection,
        gym.wrappers.human_rendering.HumanRendering,
    )
    unwrapped = env
    while True:
        if not isinstance(unwrapped, builtin_wrappers):
            return unwrapped
        # Step one level down the wrapper chain and check again.
        unwrapped = unwrapped.env
def _convert_space(space: "gym.Space") -> gymnasium.Space:
    """Converts a gym space to a gymnasium space.

    Container spaces (``Tuple``, ``Dict``, ``Sequence``, ``Graph``) are
    converted recursively.

    Args:
        space: the space to convert

    Returns:
        The converted space

    Raises:
        NotImplementedError: if the space is not one of the handled gym space types
    """
    if isinstance(space, gym.spaces.Discrete):
        # Carry over the offset of the first valid value; without it a space
        # such as Discrete(3, start=-1) would be converted to Discrete(3).
        # The fallback of 0 covers space objects that lack a ``start`` attribute.
        return gymnasium.spaces.Discrete(n=space.n, start=getattr(space, "start", 0))
    elif isinstance(space, gym.spaces.Box):
        return gymnasium.spaces.Box(
            low=space.low, high=space.high, shape=space.shape, dtype=space.dtype
        )
    elif isinstance(space, gym.spaces.MultiDiscrete):
        return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
    elif isinstance(space, gym.spaces.MultiBinary):
        return gymnasium.spaces.MultiBinary(n=space.n)
    elif isinstance(space, gym.spaces.Tuple):
        return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces)))
    elif isinstance(space, gym.spaces.Dict):
        return gymnasium.spaces.Dict(
            spaces={k: _convert_space(v) for k, v in space.spaces.items()}
        )
    elif isinstance(space, gym.spaces.Sequence):
        return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space))
    elif isinstance(space, gym.spaces.Graph):
        return gymnasium.spaces.Graph(
            node_space=_convert_space(space.node_space),  # type: ignore
            edge_space=_convert_space(space.edge_space),  # type: ignore
        )
    elif isinstance(space, gym.spaces.Text):
        return gymnasium.spaces.Text(
            max_length=space.max_length,
            min_length=space.min_length,
            # NOTE(review): relies on the private ``_char_str`` attribute of gym's Text space.
            charset=space._char_str,
        )
    else:
        raise NotImplementedError(
            f"Cannot convert space of type {space}. Please upgrade your code to gymnasium."
        )

View File

@@ -65,6 +65,10 @@ class InvalidAction(Error):
"""Raised when the user performs an action not contained within the action space.""" """Raised when the user performs an action not contained within the action space."""
class MissingArgument(Error):
    """Raised when a required argument in the initializer is missing.

    Derives from ``Error``, so handlers that catch ``Error`` also catch it.
    """
# API errors # API errors

View File

@@ -47,7 +47,7 @@ extras = {
testing_group = set(extras.keys()) - {"accept-rom-license", "atari"} testing_group = set(extras.keys()) - {"accept-rom-license", "atari"}
extras["testing"] = list( extras["testing"] = list(
set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group))) set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group)))
) + ["pytest==7.0.1"] ) + ["pytest==7.0.1", "gym==0.26.2"]
# All dependency groups - accept rom license as requires user to run # All dependency groups - accept rom license as requires user to run
all_groups = set(extras.keys()) - {"accept-rom-license"} all_groups = set(extras.keys()) - {"accept-rom-license"}

View File

@@ -7,3 +7,4 @@ imageio>=2.14.1
pygame==2.1.0 pygame==2.1.0
mujoco_py<2.2,>=2.1 mujoco_py<2.2,>=2.1
pytest==7.0.1 pytest==7.0.1
gym==0.26.2

View File

@@ -3,12 +3,12 @@ from typing import Any, Dict, Optional, Tuple
import numpy as np import numpy as np
import gymnasium as gym import gymnasium
from gymnasium.spaces import Discrete from gymnasium.spaces import Discrete
from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv
class LegacyEnvExplicit(LegacyEnv, gym.Env): class LegacyEnvExplicit(LegacyEnv, gymnasium.Env):
"""Legacy env that explicitly implements the old API.""" """Legacy env that explicitly implements the old API."""
observation_space = Discrete(1) observation_space = Discrete(1)
@@ -37,7 +37,7 @@ class LegacyEnvExplicit(LegacyEnv, gym.Env):
pass pass
class LegacyEnvImplicit(gym.Env): class LegacyEnvImplicit(gymnasium.Env):
"""Legacy env that implicitly implements the old API as a protocol.""" """Legacy env that implicitly implements the old API as a protocol."""
observation_space = Discrete(1) observation_space = Discrete(1)
@@ -95,12 +95,12 @@ def test_implicit():
def test_make_compatibility_in_spec(): def test_make_compatibility_in_spec():
gym.register( gymnasium.register(
id="LegacyTestEnv-v0", id="LegacyTestEnv-v0",
entry_point=LegacyEnvExplicit, entry_point=LegacyEnvExplicit,
apply_api_compatibility=True, apply_api_compatibility=True,
) )
env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array") env = gymnasium.make("LegacyTestEnv-v0", render_mode="rgb_array")
assert env.observation_space == Discrete(1) assert env.observation_space == Discrete(1)
assert env.action_space == Discrete(1) assert env.action_space == Discrete(1)
assert env.reset() == (0, {}) assert env.reset() == (0, {})
@@ -110,12 +110,12 @@ def test_make_compatibility_in_spec():
assert isinstance(img, np.ndarray) assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore assert img.shape == (1, 1, 3) # type: ignore
env.close() env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"] del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]
def test_make_compatibility_in_make(): def test_make_compatibility_in_make():
gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit) gymnasium.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
env = gym.make( env = gymnasium.make(
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array" "LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
) )
assert env.observation_space == Discrete(1) assert env.observation_space == Discrete(1)
@@ -127,4 +127,4 @@ def test_make_compatibility_in_make():
assert isinstance(img, np.ndarray) assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore assert img.shape == (1, 1, 3) # type: ignore
env.close() env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"] del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]

View File

@@ -0,0 +1,27 @@
import pytest
import gymnasium
from gymnasium.utils.env_checker import check_env
pytest.importorskip("gym")
import gym # noqa: E402, isort: skip
ALL_GYM_ENVS = gym.envs.registry.keys()
@pytest.mark.parametrize("env_id", ALL_GYM_ENVS, ids=list(ALL_GYM_ENVS))
def test_gym_conversion_by_id(env_id):
    """Builds the converter from a gym env id and runs the env checker on it."""
    converted_env = gymnasium.make("GymV26Environment-v0", env_id=env_id)
    check_env(converted_env)
@pytest.mark.parametrize("env_id", ALL_GYM_ENVS, ids=list(ALL_GYM_ENVS))
def test_gym_conversion_instantiated(env_id):
    """Wraps an already created gym env in the converter and runs the env checker on it."""
    legacy_env = gym.make(env_id)
    converted_env = gymnasium.make("GymV26Environment-v0", env=legacy_env)
    check_env(converted_env)

View File

@@ -17,7 +17,11 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
if "gymnasium.envs." in env_spec.entry_point: if "gymnasium.envs." in env_spec.entry_point:
try: try:
return env_spec.make(disable_env_checker=True).unwrapped return env_spec.make(disable_env_checker=True).unwrapped
except (ImportError, gym.error.DependencyNotInstalled) as e: except (
ImportError,
gym.error.DependencyNotInstalled,
gym.error.MissingArgument,
) as e:
logger.warn(f"Not testing {env_spec.id} due to error: {e}") logger.warn(f"Not testing {env_spec.id} due to error: {e}")
return None return None