mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Gym-Gymnasium compatibility converter (#61)
This commit is contained in:
committed by
GitHub
parent
e7c8a8cb59
commit
8b81b7dcc2
@@ -318,3 +318,11 @@ register(
|
||||
entry_point="gymnasium.envs.mujoco.humanoidstandup_v4:HumanoidStandupEnv",
|
||||
max_episode_steps=1000,
|
||||
)
|
||||
|
||||
|
||||
# Gym conversion
|
||||
# ----------------------------------------
|
||||
register(
|
||||
id="GymV26Environment-v0",
|
||||
entry_point="gymnasium.envs.external.gym_env:GymEnvironment",
|
||||
)
|
||||
|
0
gymnasium/envs/external/__init__.py
vendored
Normal file
0
gymnasium/envs/external/__init__.py
vendored
Normal file
159
gymnasium/envs/external/gym_env.py
vendored
Normal file
159
gymnasium/envs/external/gym_env.py
vendored
Normal file
@@ -0,0 +1,159 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium import error
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
try:
|
||||
import gym
|
||||
import gym.wrappers
|
||||
except ImportError as e:
|
||||
GYM_IMPORT_ERROR = e
|
||||
else:
|
||||
GYM_IMPORT_ERROR = None
|
||||
|
||||
|
||||
class GymEnvironment(gymnasium.Env):
|
||||
"""
|
||||
Converts a gym environment to a gymnasium environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_id: Optional[str] = None,
|
||||
make_kwargs: Optional[dict] = None,
|
||||
env: Optional["gym.Env"] = None,
|
||||
):
|
||||
if GYM_IMPORT_ERROR is not None:
|
||||
raise error.DependencyNotInstalled(
|
||||
f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments"
|
||||
)
|
||||
|
||||
if make_kwargs is None:
|
||||
make_kwargs = {}
|
||||
|
||||
if env is not None:
|
||||
self.gym_env = env
|
||||
elif env_id is not None:
|
||||
self.gym_env = gym.make(env_id, **make_kwargs)
|
||||
else:
|
||||
raise gymnasium.error.MissingArgument(
|
||||
"Either env_id or env must be provided to create a legacy gym environment."
|
||||
)
|
||||
self.gym_env = _strip_default_wrappers(self.gym_env)
|
||||
|
||||
self.observation_space = _convert_space(self.gym_env.observation_space)
|
||||
self.action_space = _convert_space(self.gym_env.action_space)
|
||||
|
||||
self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []})
|
||||
self.render_mode = self.gym_env.render_mode
|
||||
self.reward_range = getattr(self.gym_env, "reward_range", None)
|
||||
self.spec = getattr(self.gym_env, "spec", None)
|
||||
|
||||
def reset(
|
||||
self, seed: Optional[int] = None, options: Optional[dict] = None
|
||||
) -> Tuple[ObsType, dict]:
|
||||
"""Resets the environment.
|
||||
|
||||
Args:
|
||||
seed: the seed to reset the environment with
|
||||
options: the options to reset the environment with
|
||||
|
||||
Returns:
|
||||
(observation, info)
|
||||
"""
|
||||
super().reset(seed=seed)
|
||||
# Options are ignored
|
||||
return self.gym_env.reset(seed=seed, options=options)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||
"""Steps through the environment.
|
||||
|
||||
Args:
|
||||
action: action to step through the environment with
|
||||
|
||||
Returns:
|
||||
(observation, reward, terminated, truncated, info)
|
||||
"""
|
||||
return self.gym_env.step(action)
|
||||
|
||||
def render(self):
|
||||
"""Renders the environment.
|
||||
|
||||
Returns:
|
||||
The rendering of the environment, depending on the render mode
|
||||
"""
|
||||
return self.gym_env.render()
|
||||
|
||||
def close(self):
|
||||
"""Closes the environment."""
|
||||
self.gym_env.close()
|
||||
|
||||
def __str__(self):
|
||||
return f"GymEnvironment({self.gym_env})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"GymEnvironment({self.gym_env})"
|
||||
|
||||
|
||||
def _strip_default_wrappers(env: "gym.Env") -> "gym.Env":
|
||||
"""Strips builtin wrappers from the environment.
|
||||
|
||||
Args:
|
||||
env: the environment to strip builtin wrappers from
|
||||
|
||||
Returns:
|
||||
The environment without builtin wrappers
|
||||
"""
|
||||
|
||||
default_wrappers = (
|
||||
gym.wrappers.render_collection.RenderCollection,
|
||||
gym.wrappers.human_rendering.HumanRendering,
|
||||
)
|
||||
while isinstance(env, default_wrappers):
|
||||
env = env.env
|
||||
return env
|
||||
|
||||
|
||||
def _convert_space(space: "gym.Space") -> gymnasium.Space:
|
||||
"""Converts a gym space to a gymnasium space.
|
||||
|
||||
Args:
|
||||
space: the space to convert
|
||||
|
||||
Returns:
|
||||
The converted space
|
||||
"""
|
||||
if isinstance(space, gym.spaces.Discrete):
|
||||
return gymnasium.spaces.Discrete(n=space.n)
|
||||
elif isinstance(space, gym.spaces.Box):
|
||||
return gymnasium.spaces.Box(
|
||||
low=space.low, high=space.high, shape=space.shape, dtype=space.dtype
|
||||
)
|
||||
elif isinstance(space, gym.spaces.MultiDiscrete):
|
||||
return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
|
||||
elif isinstance(space, gym.spaces.MultiBinary):
|
||||
return gymnasium.spaces.MultiBinary(n=space.n)
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces)))
|
||||
elif isinstance(space, gym.spaces.Dict):
|
||||
return gymnasium.spaces.Dict(
|
||||
spaces={k: _convert_space(v) for k, v in space.spaces.items()}
|
||||
)
|
||||
elif isinstance(space, gym.spaces.Sequence):
|
||||
return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space))
|
||||
elif isinstance(space, gym.spaces.Graph):
|
||||
return gymnasium.spaces.Graph(
|
||||
node_space=_convert_space(space.node_space), # type: ignore
|
||||
edge_space=_convert_space(space.edge_space), # type: ignore
|
||||
)
|
||||
elif isinstance(space, gym.spaces.Text):
|
||||
return gymnasium.spaces.Text(
|
||||
max_length=space.max_length,
|
||||
min_length=space.min_length,
|
||||
charset=space._char_str,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Cannot convert space of type {space}. Please upgrade your code to gymnasium."
|
||||
)
|
@@ -65,6 +65,10 @@ class InvalidAction(Error):
|
||||
"""Raised when the user performs an action not contained within the action space."""
|
||||
|
||||
|
||||
class MissingArgument(Error):
|
||||
"""Raised when a required argument in the initializer is missing."""
|
||||
|
||||
|
||||
# API errors
|
||||
|
||||
|
||||
|
2
setup.py
2
setup.py
@@ -47,7 +47,7 @@ extras = {
|
||||
testing_group = set(extras.keys()) - {"accept-rom-license", "atari"}
|
||||
extras["testing"] = list(
|
||||
set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group)))
|
||||
) + ["pytest==7.0.1"]
|
||||
) + ["pytest==7.0.1", "gym==0.26.2"]
|
||||
|
||||
# All dependency groups - accept rom license as requires user to run
|
||||
all_groups = set(extras.keys()) - {"accept-rom-license"}
|
||||
|
@@ -7,3 +7,4 @@ imageio>=2.14.1
|
||||
pygame==2.1.0
|
||||
mujoco_py<2.2,>=2.1
|
||||
pytest==7.0.1
|
||||
gym==0.26.2
|
||||
|
@@ -3,12 +3,12 @@ from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
import gymnasium
|
||||
from gymnasium.spaces import Discrete
|
||||
from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv
|
||||
|
||||
|
||||
class LegacyEnvExplicit(LegacyEnv, gym.Env):
|
||||
class LegacyEnvExplicit(LegacyEnv, gymnasium.Env):
|
||||
"""Legacy env that explicitly implements the old API."""
|
||||
|
||||
observation_space = Discrete(1)
|
||||
@@ -37,7 +37,7 @@ class LegacyEnvExplicit(LegacyEnv, gym.Env):
|
||||
pass
|
||||
|
||||
|
||||
class LegacyEnvImplicit(gym.Env):
|
||||
class LegacyEnvImplicit(gymnasium.Env):
|
||||
"""Legacy env that implicitly implements the old API as a protocol."""
|
||||
|
||||
observation_space = Discrete(1)
|
||||
@@ -95,12 +95,12 @@ def test_implicit():
|
||||
|
||||
|
||||
def test_make_compatibility_in_spec():
|
||||
gym.register(
|
||||
gymnasium.register(
|
||||
id="LegacyTestEnv-v0",
|
||||
entry_point=LegacyEnvExplicit,
|
||||
apply_api_compatibility=True,
|
||||
)
|
||||
env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array")
|
||||
env = gymnasium.make("LegacyTestEnv-v0", render_mode="rgb_array")
|
||||
assert env.observation_space == Discrete(1)
|
||||
assert env.action_space == Discrete(1)
|
||||
assert env.reset() == (0, {})
|
||||
@@ -110,12 +110,12 @@ def test_make_compatibility_in_spec():
|
||||
assert isinstance(img, np.ndarray)
|
||||
assert img.shape == (1, 1, 3) # type: ignore
|
||||
env.close()
|
||||
del gym.envs.registration.registry["LegacyTestEnv-v0"]
|
||||
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]
|
||||
|
||||
|
||||
def test_make_compatibility_in_make():
|
||||
gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
|
||||
env = gym.make(
|
||||
gymnasium.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
|
||||
env = gymnasium.make(
|
||||
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
|
||||
)
|
||||
assert env.observation_space == Discrete(1)
|
||||
@@ -127,4 +127,4 @@ def test_make_compatibility_in_make():
|
||||
assert isinstance(img, np.ndarray)
|
||||
assert img.shape == (1, 1, 3) # type: ignore
|
||||
env.close()
|
||||
del gym.envs.registration.registry["LegacyTestEnv-v0"]
|
||||
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]
|
||||
|
27
tests/envs/test_gym_conversion.py
Normal file
27
tests/envs/test_gym_conversion.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import pytest
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
pytest.importorskip("gym")
|
||||
|
||||
import gym # noqa: E402, isort: skip
|
||||
|
||||
ALL_GYM_ENVS = gym.envs.registry.keys()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
|
||||
)
|
||||
def test_gym_conversion_by_id(env_id):
|
||||
env = gymnasium.make("GymV26Environment-v0", env_id=env_id)
|
||||
check_env(env)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
|
||||
)
|
||||
def test_gym_conversion_instantiated(env_id):
|
||||
env = gym.make(env_id)
|
||||
env = gymnasium.make("GymV26Environment-v0", env=env)
|
||||
check_env(env)
|
@@ -17,7 +17,11 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
|
||||
if "gymnasium.envs." in env_spec.entry_point:
|
||||
try:
|
||||
return env_spec.make(disable_env_checker=True).unwrapped
|
||||
except (ImportError, gym.error.DependencyNotInstalled) as e:
|
||||
except (
|
||||
ImportError,
|
||||
gym.error.DependencyNotInstalled,
|
||||
gym.error.MissingArgument,
|
||||
) as e:
|
||||
logger.warn(f"Not testing {env_spec.id} due to error: {e}")
|
||||
return None
|
||||
|
||||
|
Reference in New Issue
Block a user