Gym-Gymnasium compatibility converter (#61)

This commit is contained in:
Ariel Kwiatkowski
2022-10-20 11:30:14 +02:00
committed by GitHub
parent e7c8a8cb59
commit 8b81b7dcc2
9 changed files with 214 additions and 11 deletions

View File

@@ -318,3 +318,11 @@ register(
entry_point="gymnasium.envs.mujoco.humanoidstandup_v4:HumanoidStandupEnv",
max_episode_steps=1000,
)
# Gym conversion
# ----------------------------------------
register(
id="GymV26Environment-v0",
entry_point="gymnasium.envs.external.gym_env:GymEnvironment",
)

0
gymnasium/envs/external/__init__.py vendored Normal file
View File

159
gymnasium/envs/external/gym_env.py vendored Normal file
View File

@@ -0,0 +1,159 @@
from typing import Optional, Tuple
import gymnasium
from gymnasium import error
from gymnasium.core import ActType, ObsType
try:
import gym
import gym.wrappers
except ImportError as e:
GYM_IMPORT_ERROR = e
else:
GYM_IMPORT_ERROR = None
class GymEnvironment(gymnasium.Env):
"""
Converts a gym environment to a gymnasium environment.
"""
def __init__(
self,
env_id: Optional[str] = None,
make_kwargs: Optional[dict] = None,
env: Optional["gym.Env"] = None,
):
if GYM_IMPORT_ERROR is not None:
raise error.DependencyNotInstalled(
f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments"
)
if make_kwargs is None:
make_kwargs = {}
if env is not None:
self.gym_env = env
elif env_id is not None:
self.gym_env = gym.make(env_id, **make_kwargs)
else:
raise gymnasium.error.MissingArgument(
"Either env_id or env must be provided to create a legacy gym environment."
)
self.gym_env = _strip_default_wrappers(self.gym_env)
self.observation_space = _convert_space(self.gym_env.observation_space)
self.action_space = _convert_space(self.gym_env.action_space)
self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []})
self.render_mode = self.gym_env.render_mode
self.reward_range = getattr(self.gym_env, "reward_range", None)
self.spec = getattr(self.gym_env, "spec", None)
def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[ObsType, dict]:
"""Resets the environment.
Args:
seed: the seed to reset the environment with
options: the options to reset the environment with
Returns:
(observation, info)
"""
super().reset(seed=seed)
# Options are ignored
return self.gym_env.reset(seed=seed, options=options)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""Steps through the environment.
Args:
action: action to step through the environment with
Returns:
(observation, reward, terminated, truncated, info)
"""
return self.gym_env.step(action)
def render(self):
"""Renders the environment.
Returns:
The rendering of the environment, depending on the render mode
"""
return self.gym_env.render()
def close(self):
"""Closes the environment."""
self.gym_env.close()
def __str__(self):
return f"GymEnvironment({self.gym_env})"
def __repr__(self):
return f"GymEnvironment({self.gym_env})"
def _strip_default_wrappers(env: "gym.Env") -> "gym.Env":
"""Strips builtin wrappers from the environment.
Args:
env: the environment to strip builtin wrappers from
Returns:
The environment without builtin wrappers
"""
default_wrappers = (
gym.wrappers.render_collection.RenderCollection,
gym.wrappers.human_rendering.HumanRendering,
)
while isinstance(env, default_wrappers):
env = env.env
return env
def _convert_space(space: "gym.Space") -> gymnasium.Space:
"""Converts a gym space to a gymnasium space.
Args:
space: the space to convert
Returns:
The converted space
"""
if isinstance(space, gym.spaces.Discrete):
return gymnasium.spaces.Discrete(n=space.n)
elif isinstance(space, gym.spaces.Box):
return gymnasium.spaces.Box(
low=space.low, high=space.high, shape=space.shape, dtype=space.dtype
)
elif isinstance(space, gym.spaces.MultiDiscrete):
return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
elif isinstance(space, gym.spaces.MultiBinary):
return gymnasium.spaces.MultiBinary(n=space.n)
elif isinstance(space, gym.spaces.Tuple):
return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces)))
elif isinstance(space, gym.spaces.Dict):
return gymnasium.spaces.Dict(
spaces={k: _convert_space(v) for k, v in space.spaces.items()}
)
elif isinstance(space, gym.spaces.Sequence):
return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space))
elif isinstance(space, gym.spaces.Graph):
return gymnasium.spaces.Graph(
node_space=_convert_space(space.node_space), # type: ignore
edge_space=_convert_space(space.edge_space), # type: ignore
)
elif isinstance(space, gym.spaces.Text):
return gymnasium.spaces.Text(
max_length=space.max_length,
min_length=space.min_length,
charset=space._char_str,
)
else:
raise NotImplementedError(
f"Cannot convert space of type {space}. Please upgrade your code to gymnasium."
)

View File

@@ -65,6 +65,10 @@ class InvalidAction(Error):
"""Raised when the user performs an action not contained within the action space."""
class MissingArgument(Error):
"""Raised when a required argument in the initializer is missing."""
# API errors

View File

@@ -47,7 +47,7 @@ extras = {
testing_group = set(extras.keys()) - {"accept-rom-license", "atari"}
extras["testing"] = list(
set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group)))
) + ["pytest==7.0.1"]
) + ["pytest==7.0.1", "gym==0.26.2"]
# All dependency groups - accept rom license as requires user to run
all_groups = set(extras.keys()) - {"accept-rom-license"}

View File

@@ -7,3 +7,4 @@ imageio>=2.14.1
pygame==2.1.0
mujoco_py<2.2,>=2.1
pytest==7.0.1
gym==0.26.2

View File

@@ -3,12 +3,12 @@ from typing import Any, Dict, Optional, Tuple
import numpy as np
import gymnasium as gym
import gymnasium
from gymnasium.spaces import Discrete
from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv
class LegacyEnvExplicit(LegacyEnv, gym.Env):
class LegacyEnvExplicit(LegacyEnv, gymnasium.Env):
"""Legacy env that explicitly implements the old API."""
observation_space = Discrete(1)
@@ -37,7 +37,7 @@ class LegacyEnvExplicit(LegacyEnv, gym.Env):
pass
class LegacyEnvImplicit(gym.Env):
class LegacyEnvImplicit(gymnasium.Env):
"""Legacy env that implicitly implements the old API as a protocol."""
observation_space = Discrete(1)
@@ -95,12 +95,12 @@ def test_implicit():
def test_make_compatibility_in_spec():
gym.register(
gymnasium.register(
id="LegacyTestEnv-v0",
entry_point=LegacyEnvExplicit,
apply_api_compatibility=True,
)
env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array")
env = gymnasium.make("LegacyTestEnv-v0", render_mode="rgb_array")
assert env.observation_space == Discrete(1)
assert env.action_space == Discrete(1)
assert env.reset() == (0, {})
@@ -110,12 +110,12 @@ def test_make_compatibility_in_spec():
assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore
env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"]
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]
def test_make_compatibility_in_make():
gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
env = gym.make(
gymnasium.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
env = gymnasium.make(
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
)
assert env.observation_space == Discrete(1)
@@ -127,4 +127,4 @@ def test_make_compatibility_in_make():
assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore
env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"]
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]

View File

@@ -0,0 +1,27 @@
import pytest
import gymnasium
from gymnasium.utils.env_checker import check_env
pytest.importorskip("gym")
import gym # noqa: E402, isort: skip
ALL_GYM_ENVS = gym.envs.registry.keys()
@pytest.mark.parametrize(
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
)
def test_gym_conversion_by_id(env_id):
env = gymnasium.make("GymV26Environment-v0", env_id=env_id)
check_env(env)
@pytest.mark.parametrize(
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
)
def test_gym_conversion_instantiated(env_id):
env = gym.make(env_id)
env = gymnasium.make("GymV26Environment-v0", env=env)
check_env(env)

View File

@@ -17,7 +17,11 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
if "gymnasium.envs." in env_spec.entry_point:
try:
return env_spec.make(disable_env_checker=True).unwrapped
except (ImportError, gym.error.DependencyNotInstalled) as e:
except (
ImportError,
gym.error.DependencyNotInstalled,
gym.error.MissingArgument,
) as e:
logger.warn(f"Not testing {env_spec.id} due to error: {e}")
return None