mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 06:16:32 +00:00
Add shimmy for atari and removes the gym compatibility for the shimmy versions (#125)
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
from gymnasium.envs.registration import load_env_plugins as _load_env_plugins
|
||||
from gymnasium.envs.registration import make, pprint_registry, register, registry, spec
|
||||
|
||||
# Hook to load plugins from entry points
|
||||
_load_env_plugins()
|
||||
|
||||
"""Registers the internal gym envs then loads the env plugins for module using the entry point."""
|
||||
from gymnasium.envs.registration import (
|
||||
load_env_plugins,
|
||||
make,
|
||||
pprint_registry,
|
||||
register,
|
||||
registry,
|
||||
spec,
|
||||
)
|
||||
|
||||
# Classic
|
||||
# ----------------------------------------
|
||||
@@ -344,9 +347,5 @@ register(
|
||||
)
|
||||
|
||||
|
||||
# Gym conversion
|
||||
# ----------------------------------------
|
||||
register(
|
||||
id="GymV26Environment-v0",
|
||||
entry_point="gymnasium.envs.external.gym_env:GymEnvironment",
|
||||
)
|
||||
# Hook to load plugins from entry points
|
||||
load_env_plugins()
|
||||
|
0
gymnasium/envs/external/__init__.py
vendored
0
gymnasium/envs/external/__init__.py
vendored
159
gymnasium/envs/external/gym_env.py
vendored
159
gymnasium/envs/external/gym_env.py
vendored
@@ -1,159 +0,0 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium import error
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
try:
|
||||
import gym
|
||||
except ImportError as e:
|
||||
GYM_IMPORT_ERROR = e
|
||||
else:
|
||||
GYM_IMPORT_ERROR = None
|
||||
|
||||
|
||||
class GymEnvironment(gymnasium.Env):
|
||||
"""
|
||||
Converts a gym environment to a gymnasium environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_id: Optional[str] = None,
|
||||
make_kwargs: Optional[dict] = None,
|
||||
env: Optional["gym.Env"] = None,
|
||||
):
|
||||
if GYM_IMPORT_ERROR is not None:
|
||||
raise error.DependencyNotInstalled(
|
||||
f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments"
|
||||
)
|
||||
|
||||
if make_kwargs is None:
|
||||
make_kwargs = {}
|
||||
|
||||
if env is not None:
|
||||
self.gym_env = env
|
||||
elif env_id is not None:
|
||||
self.gym_env = gym.make(env_id, **make_kwargs)
|
||||
else:
|
||||
raise gymnasium.error.MissingArgument(
|
||||
"Either env_id or env must be provided to create a legacy gym environment."
|
||||
)
|
||||
self.gym_env = _strip_default_wrappers(self.gym_env)
|
||||
|
||||
self.observation_space = _convert_space(self.gym_env.observation_space)
|
||||
self.action_space = _convert_space(self.gym_env.action_space)
|
||||
|
||||
self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []})
|
||||
self.render_mode = self.gym_env.render_mode
|
||||
self.reward_range = getattr(self.gym_env, "reward_range", None)
|
||||
self.spec = getattr(self.gym_env, "spec", None)
|
||||
|
||||
def reset(
|
||||
self, seed: Optional[int] = None, options: Optional[dict] = None
|
||||
) -> Tuple[ObsType, dict]:
|
||||
"""Resets the environment.
|
||||
|
||||
Args:
|
||||
seed: the seed to reset the environment with
|
||||
options: the options to reset the environment with
|
||||
|
||||
Returns:
|
||||
(observation, info)
|
||||
"""
|
||||
super().reset(seed=seed)
|
||||
# Options are ignored
|
||||
return self.gym_env.reset(seed=seed, options=options)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||
"""Steps through the environment.
|
||||
|
||||
Args:
|
||||
action: action to step through the environment with
|
||||
|
||||
Returns:
|
||||
(observation, reward, terminated, truncated, info)
|
||||
"""
|
||||
return self.gym_env.step(action)
|
||||
|
||||
def render(self):
|
||||
"""Renders the environment.
|
||||
|
||||
Returns:
|
||||
The rendering of the environment, depending on the render mode
|
||||
"""
|
||||
return self.gym_env.render()
|
||||
|
||||
def close(self):
|
||||
"""Closes the environment."""
|
||||
self.gym_env.close()
|
||||
|
||||
def __str__(self):
|
||||
return f"GymEnvironment({self.gym_env})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"GymEnvironment({self.gym_env})"
|
||||
|
||||
|
||||
def _strip_default_wrappers(env: "gym.Env") -> "gym.Env":
|
||||
"""Strips builtin wrappers from the environment.
|
||||
|
||||
Args:
|
||||
env: the environment to strip builtin wrappers from
|
||||
|
||||
Returns:
|
||||
The environment without builtin wrappers
|
||||
"""
|
||||
import gym.wrappers
|
||||
|
||||
default_wrappers = (
|
||||
gym.wrappers.render_collection.RenderCollection,
|
||||
gym.wrappers.human_rendering.HumanRendering,
|
||||
)
|
||||
while isinstance(env, default_wrappers):
|
||||
env = env.env
|
||||
return env
|
||||
|
||||
|
||||
def _convert_space(space: "gym.Space") -> gymnasium.Space:
|
||||
"""Converts a gym space to a gymnasium space.
|
||||
|
||||
Args:
|
||||
space: the space to convert
|
||||
|
||||
Returns:
|
||||
The converted space
|
||||
"""
|
||||
if isinstance(space, gym.spaces.Discrete):
|
||||
return gymnasium.spaces.Discrete(n=space.n)
|
||||
elif isinstance(space, gym.spaces.Box):
|
||||
return gymnasium.spaces.Box(
|
||||
low=space.low, high=space.high, shape=space.shape, dtype=space.dtype
|
||||
)
|
||||
elif isinstance(space, gym.spaces.MultiDiscrete):
|
||||
return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
|
||||
elif isinstance(space, gym.spaces.MultiBinary):
|
||||
return gymnasium.spaces.MultiBinary(n=space.n)
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces)))
|
||||
elif isinstance(space, gym.spaces.Dict):
|
||||
return gymnasium.spaces.Dict(
|
||||
spaces={k: _convert_space(v) for k, v in space.spaces.items()}
|
||||
)
|
||||
elif isinstance(space, gym.spaces.Sequence):
|
||||
return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space))
|
||||
elif isinstance(space, gym.spaces.Graph):
|
||||
return gymnasium.spaces.Graph(
|
||||
node_space=_convert_space(space.node_space), # type: ignore
|
||||
edge_space=_convert_space(space.edge_space), # type: ignore
|
||||
)
|
||||
elif isinstance(space, gym.spaces.Text):
|
||||
return gymnasium.spaces.Text(
|
||||
max_length=space.max_length,
|
||||
min_length=space.min_length,
|
||||
charset=space._char_str,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Cannot convert space of type {space}. Please upgrade your code to gymnasium."
|
||||
)
|
@@ -3,6 +3,7 @@ import sys
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import logger
|
||||
from gymnasium.core import ObsType
|
||||
from gymnasium.utils.step_api_compatibility import (
|
||||
convert_to_terminated_truncated_step_api,
|
||||
@@ -62,6 +63,10 @@ class EnvCompatibility(gym.Env):
|
||||
old_env (LegacyEnv): the env to wrap, implemented with the old API
|
||||
render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render
|
||||
"""
|
||||
logger.warn(
|
||||
"The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v28. "
|
||||
"Instead use `gym.make('GymV22Environment-v0', env_name=...)` or `from shimmy import GymV26CompatibilityV0`"
|
||||
)
|
||||
self.metadata = getattr(old_env, "metadata", {"render_modes": []})
|
||||
self.render_mode = render_mode
|
||||
self.reward_range = getattr(old_env, "reward_range", None)
|
||||
|
4
setup.py
4
setup.py
@@ -33,7 +33,7 @@ def get_version():
|
||||
|
||||
# Environment-specific dependencies.
|
||||
extras = {
|
||||
"atari": ["ale-py~=0.8.0"],
|
||||
"atari": ["shimmy[atari]>=0.1.0,<1.0"],
|
||||
"accept-rom-license": ["autorom[accept-rom-license]~=0.4.2"],
|
||||
"box2d": ["box2d-py==2.3.5", "pygame==2.1.0", "swig==4.*"],
|
||||
"classic_control": ["pygame==2.1.0"],
|
||||
@@ -46,7 +46,6 @@ extras = {
|
||||
|
||||
extras["testing"] = list(set(itertools.chain.from_iterable(extras.values()))) + [
|
||||
"pytest==7.1.3",
|
||||
"gym[classic_control, mujoco_py, mujoco, toy_text, other, atari, accept-rom-license]==0.26.2",
|
||||
]
|
||||
|
||||
# All dependency groups - accept rom license as requires user to run
|
||||
@@ -90,6 +89,7 @@ setup(
|
||||
"cloudpickle >= 1.2.0",
|
||||
"importlib_metadata >= 4.8.0; python_version < '3.10'",
|
||||
"gymnasium_notices >= 0.0.1",
|
||||
"shimmy>=0.1.0, <1.0",
|
||||
],
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
|
@@ -1,43 +0,0 @@
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
from tests.envs.test_envs import CHECK_ENV_IGNORE_WARNINGS
|
||||
|
||||
pytest.importorskip("gym")
|
||||
|
||||
import gym # noqa: E402, isort: skip
|
||||
|
||||
# We do not test Atari environment's here because we check all variants of Pong in test_envs.py (There are too many Atari environments)
|
||||
ALL_GYM_ENVS = [
|
||||
env_id
|
||||
for env_id, spec in gym.envs.registry.items()
|
||||
if ("ale_py" not in spec.entry_point or "Pong" in env_id)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
|
||||
)
|
||||
def test_gym_conversion_by_id(env_id):
|
||||
env = gymnasium.make("GymV26Environment-v0", env_id=env_id).unwrapped
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
check_env(env, skip_render_check=True)
|
||||
for warning in caught_warnings:
|
||||
if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
|
||||
raise gymnasium.error.Error(f"Unexpected warning: {warning.message}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
|
||||
)
|
||||
def test_gym_conversion_instantiated(env_id):
|
||||
env = gym.make(env_id)
|
||||
env = gymnasium.make("GymV26Environment-v0", env=env).unwrapped
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
check_env(env, skip_render_check=True)
|
||||
for warning in caught_warnings:
|
||||
if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
|
||||
raise gymnasium.error.Error(f"Unexpected warning: {warning.message}")
|
@@ -22,33 +22,45 @@ from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
|
||||
from tests.testing_env import GenericTestEnv, old_step_fn
|
||||
from tests.wrappers.utils import has_wrapper
|
||||
|
||||
gym.register(
|
||||
"RegisterDuringMakeEnv-v0",
|
||||
entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv",
|
||||
)
|
||||
|
||||
gym.register(
|
||||
id="test.ArgumentEnv-v0",
|
||||
entry_point="tests.envs.utils_envs:ArgumentEnv",
|
||||
kwargs={
|
||||
"arg1": "arg1",
|
||||
"arg2": "arg2",
|
||||
},
|
||||
)
|
||||
@pytest.fixture(scope="function")
|
||||
def register_make_testing_envs():
|
||||
"""Registers testing envs for `gym.make`"""
|
||||
gym.register(
|
||||
"RegisterDuringMakeEnv-v0",
|
||||
entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv",
|
||||
)
|
||||
|
||||
gym.register(
|
||||
id="test/NoHuman-v0",
|
||||
entry_point="tests.envs.utils_envs:NoHuman",
|
||||
)
|
||||
gym.register(
|
||||
id="test/NoHumanOldAPI-v0",
|
||||
entry_point="tests.envs.utils_envs:NoHumanOldAPI",
|
||||
)
|
||||
gym.register(
|
||||
id="test.ArgumentEnv-v0",
|
||||
entry_point="tests.envs.utils_envs:ArgumentEnv",
|
||||
kwargs={
|
||||
"arg1": "arg1",
|
||||
"arg2": "arg2",
|
||||
},
|
||||
)
|
||||
|
||||
gym.register(
|
||||
id="test/NoHumanNoRGB-v0",
|
||||
entry_point="tests.envs.utils_envs:NoHumanNoRGB",
|
||||
)
|
||||
gym.register(
|
||||
id="test/NoHuman-v0",
|
||||
entry_point="tests.envs.utils_envs:NoHuman",
|
||||
)
|
||||
gym.register(
|
||||
id="test/NoHumanOldAPI-v0",
|
||||
entry_point="tests.envs.utils_envs:NoHumanOldAPI",
|
||||
)
|
||||
|
||||
gym.register(
|
||||
id="test/NoHumanNoRGB-v0",
|
||||
entry_point="tests.envs.utils_envs:NoHumanNoRGB",
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
del gym.envs.registration.registry["RegisterDuringMakeEnv-v0"]
|
||||
del gym.envs.registration.registry["test.ArgumentEnv-v0"]
|
||||
del gym.envs.registration.registry["test/NoHuman-v0"]
|
||||
del gym.envs.registration.registry["test/NoHumanOldAPI-v0"]
|
||||
del gym.envs.registration.registry["test/NoHumanNoRGB-v0"]
|
||||
|
||||
|
||||
def test_make():
|
||||
@@ -70,7 +82,7 @@ def test_make_deprecated():
|
||||
gym.make("Humanoid-v0", disable_env_checker=True)
|
||||
|
||||
|
||||
def test_make_max_episode_steps():
|
||||
def test_make_max_episode_steps(register_make_testing_envs):
|
||||
# Default, uses the spec's
|
||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||
assert has_wrapper(env, TimeLimit)
|
||||
@@ -208,7 +220,7 @@ def test_make_order_enforcing():
|
||||
env.close()
|
||||
|
||||
|
||||
def test_make_render_mode():
|
||||
def test_make_render_mode(register_make_testing_envs):
|
||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||
assert env.render_mode is None
|
||||
env.close()
|
||||
@@ -293,7 +305,7 @@ def test_make_render_mode():
|
||||
gym.make("CarRacing-v2", render="human")
|
||||
|
||||
|
||||
def test_make_kwargs():
|
||||
def test_make_kwargs(register_make_testing_envs):
|
||||
env = gym.make(
|
||||
"test.ArgumentEnv-v0",
|
||||
arg2="override_arg2",
|
||||
@@ -309,7 +321,7 @@ def test_make_kwargs():
|
||||
env.close()
|
||||
|
||||
|
||||
def test_import_module_during_make():
|
||||
def test_import_module_during_make(register_make_testing_envs):
|
||||
# Test custom environment which is registered at make
|
||||
env = gym.make(
|
||||
"tests.envs.utils:RegisterDuringMakeEnv-v0",
|
||||
|
@@ -1,8 +1,15 @@
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
|
||||
# To ignore the trailing whitespaces, will need flake to ignore this file.
|
||||
# flake8: noqa
|
||||
|
||||
reduced_registry = {
|
||||
env_id: env_spec
|
||||
for env_id, env_spec in gym.registry.items()
|
||||
if env_spec.entry_point != "shimmy.atari_env:AtariEnv"
|
||||
}
|
||||
|
||||
|
||||
def test_pprint_custom_registry():
|
||||
"""Testing a registry different from default."""
|
||||
|
@@ -8,8 +8,8 @@ import gymnasium as gym
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def register_testing_envs():
|
||||
"""Registers testing environments."""
|
||||
def register_registration_testing_envs():
|
||||
"""Register testing envs for `gym.register`."""
|
||||
namespace = "MyAwesomeNamespace"
|
||||
versioned_name = "MyAwesomeVersionedEnv"
|
||||
unversioned_name = "MyAwesomeUnversionedEnv"
|
||||
@@ -105,7 +105,9 @@ def test_register_error(env_id):
|
||||
("MyAwesomeNamespace/MyAwesomeVersioneEnv", "MyAwesomeVersionedEnv"),
|
||||
],
|
||||
)
|
||||
def test_env_suggestions(register_testing_envs, env_id_input, env_id_suggested):
|
||||
def test_env_suggestions(
|
||||
register_registration_testing_envs, env_id_input, env_id_suggested
|
||||
):
|
||||
with pytest.raises(
|
||||
gym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?"
|
||||
):
|
||||
@@ -124,7 +126,10 @@ def test_env_suggestions(register_testing_envs, env_id_input, env_id_suggested):
|
||||
],
|
||||
)
|
||||
def test_env_version_suggestions(
|
||||
register_testing_envs, env_id_input, suggested_versions, default_version
|
||||
register_registration_testing_envs,
|
||||
env_id_input,
|
||||
suggested_versions,
|
||||
default_version,
|
||||
):
|
||||
if default_version:
|
||||
with pytest.raises(
|
||||
@@ -173,7 +178,7 @@ def test_register_versioned_unversioned():
|
||||
del gym.envs.registry[unversioned_env]
|
||||
|
||||
|
||||
def test_make_latest_versioned_env(register_testing_envs):
|
||||
def test_make_latest_versioned_env(register_registration_testing_envs):
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match=re.escape(
|
||||
|
Reference in New Issue
Block a user