mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Gym-Gymnasium compatibility converter (#61)
This commit is contained in:
committed by
GitHub
parent
e7c8a8cb59
commit
8b81b7dcc2
@@ -318,3 +318,11 @@ register(
|
|||||||
entry_point="gymnasium.envs.mujoco.humanoidstandup_v4:HumanoidStandupEnv",
|
entry_point="gymnasium.envs.mujoco.humanoidstandup_v4:HumanoidStandupEnv",
|
||||||
max_episode_steps=1000,
|
max_episode_steps=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Gym conversion
|
||||||
|
# ----------------------------------------
|
||||||
|
register(
|
||||||
|
id="GymV26Environment-v0",
|
||||||
|
entry_point="gymnasium.envs.external.gym_env:GymEnvironment",
|
||||||
|
)
|
||||||
|
0
gymnasium/envs/external/__init__.py
vendored
Normal file
0
gymnasium/envs/external/__init__.py
vendored
Normal file
159
gymnasium/envs/external/gym_env.py
vendored
Normal file
159
gymnasium/envs/external/gym_env.py
vendored
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import gymnasium
|
||||||
|
from gymnasium import error
|
||||||
|
from gymnasium.core import ActType, ObsType
|
||||||
|
|
||||||
|
try:
|
||||||
|
import gym
|
||||||
|
import gym.wrappers
|
||||||
|
except ImportError as e:
|
||||||
|
GYM_IMPORT_ERROR = e
|
||||||
|
else:
|
||||||
|
GYM_IMPORT_ERROR = None
|
||||||
|
|
||||||
|
|
||||||
|
class GymEnvironment(gymnasium.Env):
|
||||||
|
"""
|
||||||
|
Converts a gym environment to a gymnasium environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_id: Optional[str] = None,
|
||||||
|
make_kwargs: Optional[dict] = None,
|
||||||
|
env: Optional["gym.Env"] = None,
|
||||||
|
):
|
||||||
|
if GYM_IMPORT_ERROR is not None:
|
||||||
|
raise error.DependencyNotInstalled(
|
||||||
|
f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments"
|
||||||
|
)
|
||||||
|
|
||||||
|
if make_kwargs is None:
|
||||||
|
make_kwargs = {}
|
||||||
|
|
||||||
|
if env is not None:
|
||||||
|
self.gym_env = env
|
||||||
|
elif env_id is not None:
|
||||||
|
self.gym_env = gym.make(env_id, **make_kwargs)
|
||||||
|
else:
|
||||||
|
raise gymnasium.error.MissingArgument(
|
||||||
|
"Either env_id or env must be provided to create a legacy gym environment."
|
||||||
|
)
|
||||||
|
self.gym_env = _strip_default_wrappers(self.gym_env)
|
||||||
|
|
||||||
|
self.observation_space = _convert_space(self.gym_env.observation_space)
|
||||||
|
self.action_space = _convert_space(self.gym_env.action_space)
|
||||||
|
|
||||||
|
self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []})
|
||||||
|
self.render_mode = self.gym_env.render_mode
|
||||||
|
self.reward_range = getattr(self.gym_env, "reward_range", None)
|
||||||
|
self.spec = getattr(self.gym_env, "spec", None)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self, seed: Optional[int] = None, options: Optional[dict] = None
|
||||||
|
) -> Tuple[ObsType, dict]:
|
||||||
|
"""Resets the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: the seed to reset the environment with
|
||||||
|
options: the options to reset the environment with
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(observation, info)
|
||||||
|
"""
|
||||||
|
super().reset(seed=seed)
|
||||||
|
# Options are ignored
|
||||||
|
return self.gym_env.reset(seed=seed, options=options)
|
||||||
|
|
||||||
|
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||||
|
"""Steps through the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: action to step through the environment with
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(observation, reward, terminated, truncated, info)
|
||||||
|
"""
|
||||||
|
return self.gym_env.step(action)
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
"""Renders the environment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The rendering of the environment, depending on the render mode
|
||||||
|
"""
|
||||||
|
return self.gym_env.render()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Closes the environment."""
|
||||||
|
self.gym_env.close()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"GymEnvironment({self.gym_env})"
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"GymEnvironment({self.gym_env})"
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_default_wrappers(env: "gym.Env") -> "gym.Env":
|
||||||
|
"""Strips builtin wrappers from the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: the environment to strip builtin wrappers from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The environment without builtin wrappers
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_wrappers = (
|
||||||
|
gym.wrappers.render_collection.RenderCollection,
|
||||||
|
gym.wrappers.human_rendering.HumanRendering,
|
||||||
|
)
|
||||||
|
while isinstance(env, default_wrappers):
|
||||||
|
env = env.env
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_space(space: "gym.Space") -> gymnasium.Space:
|
||||||
|
"""Converts a gym space to a gymnasium space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space: the space to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The converted space
|
||||||
|
"""
|
||||||
|
if isinstance(space, gym.spaces.Discrete):
|
||||||
|
return gymnasium.spaces.Discrete(n=space.n)
|
||||||
|
elif isinstance(space, gym.spaces.Box):
|
||||||
|
return gymnasium.spaces.Box(
|
||||||
|
low=space.low, high=space.high, shape=space.shape, dtype=space.dtype
|
||||||
|
)
|
||||||
|
elif isinstance(space, gym.spaces.MultiDiscrete):
|
||||||
|
return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
|
||||||
|
elif isinstance(space, gym.spaces.MultiBinary):
|
||||||
|
return gymnasium.spaces.MultiBinary(n=space.n)
|
||||||
|
elif isinstance(space, gym.spaces.Tuple):
|
||||||
|
return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces)))
|
||||||
|
elif isinstance(space, gym.spaces.Dict):
|
||||||
|
return gymnasium.spaces.Dict(
|
||||||
|
spaces={k: _convert_space(v) for k, v in space.spaces.items()}
|
||||||
|
)
|
||||||
|
elif isinstance(space, gym.spaces.Sequence):
|
||||||
|
return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space))
|
||||||
|
elif isinstance(space, gym.spaces.Graph):
|
||||||
|
return gymnasium.spaces.Graph(
|
||||||
|
node_space=_convert_space(space.node_space), # type: ignore
|
||||||
|
edge_space=_convert_space(space.edge_space), # type: ignore
|
||||||
|
)
|
||||||
|
elif isinstance(space, gym.spaces.Text):
|
||||||
|
return gymnasium.spaces.Text(
|
||||||
|
max_length=space.max_length,
|
||||||
|
min_length=space.min_length,
|
||||||
|
charset=space._char_str,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Cannot convert space of type {space}. Please upgrade your code to gymnasium."
|
||||||
|
)
|
@@ -65,6 +65,10 @@ class InvalidAction(Error):
|
|||||||
"""Raised when the user performs an action not contained within the action space."""
|
"""Raised when the user performs an action not contained within the action space."""
|
||||||
|
|
||||||
|
|
||||||
|
class MissingArgument(Error):
|
||||||
|
"""Raised when a required argument in the initializer is missing."""
|
||||||
|
|
||||||
|
|
||||||
# API errors
|
# API errors
|
||||||
|
|
||||||
|
|
||||||
|
2
setup.py
2
setup.py
@@ -47,7 +47,7 @@ extras = {
|
|||||||
testing_group = set(extras.keys()) - {"accept-rom-license", "atari"}
|
testing_group = set(extras.keys()) - {"accept-rom-license", "atari"}
|
||||||
extras["testing"] = list(
|
extras["testing"] = list(
|
||||||
set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group)))
|
set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group)))
|
||||||
) + ["pytest==7.0.1"]
|
) + ["pytest==7.0.1", "gym==0.26.2"]
|
||||||
|
|
||||||
# All dependency groups - accept rom license as requires user to run
|
# All dependency groups - accept rom license as requires user to run
|
||||||
all_groups = set(extras.keys()) - {"accept-rom-license"}
|
all_groups = set(extras.keys()) - {"accept-rom-license"}
|
||||||
|
@@ -7,3 +7,4 @@ imageio>=2.14.1
|
|||||||
pygame==2.1.0
|
pygame==2.1.0
|
||||||
mujoco_py<2.2,>=2.1
|
mujoco_py<2.2,>=2.1
|
||||||
pytest==7.0.1
|
pytest==7.0.1
|
||||||
|
gym==0.26.2
|
||||||
|
@@ -3,12 +3,12 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium
|
||||||
from gymnasium.spaces import Discrete
|
from gymnasium.spaces import Discrete
|
||||||
from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv
|
from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv
|
||||||
|
|
||||||
|
|
||||||
class LegacyEnvExplicit(LegacyEnv, gym.Env):
|
class LegacyEnvExplicit(LegacyEnv, gymnasium.Env):
|
||||||
"""Legacy env that explicitly implements the old API."""
|
"""Legacy env that explicitly implements the old API."""
|
||||||
|
|
||||||
observation_space = Discrete(1)
|
observation_space = Discrete(1)
|
||||||
@@ -37,7 +37,7 @@ class LegacyEnvExplicit(LegacyEnv, gym.Env):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LegacyEnvImplicit(gym.Env):
|
class LegacyEnvImplicit(gymnasium.Env):
|
||||||
"""Legacy env that implicitly implements the old API as a protocol."""
|
"""Legacy env that implicitly implements the old API as a protocol."""
|
||||||
|
|
||||||
observation_space = Discrete(1)
|
observation_space = Discrete(1)
|
||||||
@@ -95,12 +95,12 @@ def test_implicit():
|
|||||||
|
|
||||||
|
|
||||||
def test_make_compatibility_in_spec():
|
def test_make_compatibility_in_spec():
|
||||||
gym.register(
|
gymnasium.register(
|
||||||
id="LegacyTestEnv-v0",
|
id="LegacyTestEnv-v0",
|
||||||
entry_point=LegacyEnvExplicit,
|
entry_point=LegacyEnvExplicit,
|
||||||
apply_api_compatibility=True,
|
apply_api_compatibility=True,
|
||||||
)
|
)
|
||||||
env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array")
|
env = gymnasium.make("LegacyTestEnv-v0", render_mode="rgb_array")
|
||||||
assert env.observation_space == Discrete(1)
|
assert env.observation_space == Discrete(1)
|
||||||
assert env.action_space == Discrete(1)
|
assert env.action_space == Discrete(1)
|
||||||
assert env.reset() == (0, {})
|
assert env.reset() == (0, {})
|
||||||
@@ -110,12 +110,12 @@ def test_make_compatibility_in_spec():
|
|||||||
assert isinstance(img, np.ndarray)
|
assert isinstance(img, np.ndarray)
|
||||||
assert img.shape == (1, 1, 3) # type: ignore
|
assert img.shape == (1, 1, 3) # type: ignore
|
||||||
env.close()
|
env.close()
|
||||||
del gym.envs.registration.registry["LegacyTestEnv-v0"]
|
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]
|
||||||
|
|
||||||
|
|
||||||
def test_make_compatibility_in_make():
|
def test_make_compatibility_in_make():
|
||||||
gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
|
gymnasium.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
|
||||||
env = gym.make(
|
env = gymnasium.make(
|
||||||
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
|
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
|
||||||
)
|
)
|
||||||
assert env.observation_space == Discrete(1)
|
assert env.observation_space == Discrete(1)
|
||||||
@@ -127,4 +127,4 @@ def test_make_compatibility_in_make():
|
|||||||
assert isinstance(img, np.ndarray)
|
assert isinstance(img, np.ndarray)
|
||||||
assert img.shape == (1, 1, 3) # type: ignore
|
assert img.shape == (1, 1, 3) # type: ignore
|
||||||
env.close()
|
env.close()
|
||||||
del gym.envs.registration.registry["LegacyTestEnv-v0"]
|
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]
|
||||||
|
27
tests/envs/test_gym_conversion.py
Normal file
27
tests/envs/test_gym_conversion.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium
|
||||||
|
from gymnasium.utils.env_checker import check_env
|
||||||
|
|
||||||
|
pytest.importorskip("gym")
|
||||||
|
|
||||||
|
import gym # noqa: E402, isort: skip
|
||||||
|
|
||||||
|
ALL_GYM_ENVS = gym.envs.registry.keys()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
|
||||||
|
)
|
||||||
|
def test_gym_conversion_by_id(env_id):
|
||||||
|
env = gymnasium.make("GymV26Environment-v0", env_id=env_id)
|
||||||
|
check_env(env)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
|
||||||
|
)
|
||||||
|
def test_gym_conversion_instantiated(env_id):
|
||||||
|
env = gym.make(env_id)
|
||||||
|
env = gymnasium.make("GymV26Environment-v0", env=env)
|
||||||
|
check_env(env)
|
@@ -17,7 +17,11 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
|
|||||||
if "gymnasium.envs." in env_spec.entry_point:
|
if "gymnasium.envs." in env_spec.entry_point:
|
||||||
try:
|
try:
|
||||||
return env_spec.make(disable_env_checker=True).unwrapped
|
return env_spec.make(disable_env_checker=True).unwrapped
|
||||||
except (ImportError, gym.error.DependencyNotInstalled) as e:
|
except (
|
||||||
|
ImportError,
|
||||||
|
gym.error.DependencyNotInstalled,
|
||||||
|
gym.error.MissingArgument,
|
||||||
|
) as e:
|
||||||
logger.warn(f"Not testing {env_spec.id} due to error: {e}")
|
logger.warn(f"Not testing {env_spec.id} due to error: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user