Add shimmy for atari and removes the gym compatibility for the shimmy versions (#125)

This commit is contained in:
Mark Towers
2022-12-01 12:04:57 +00:00
committed by GitHub
parent 203e4e7920
commit 320b52c041
9 changed files with 75 additions and 249 deletions

View File

@@ -1,9 +1,12 @@
from gymnasium.envs.registration import load_env_plugins as _load_env_plugins
from gymnasium.envs.registration import make, pprint_registry, register, registry, spec
# Hook to load plugins from entry points
_load_env_plugins()
"""Registers the internal gym envs then loads the env plugins for module using the entry point."""
from gymnasium.envs.registration import (
load_env_plugins,
make,
pprint_registry,
register,
registry,
spec,
)
# Classic
# ----------------------------------------
@@ -344,9 +347,5 @@ register(
)
# Gym conversion
# ----------------------------------------
register(
id="GymV26Environment-v0",
entry_point="gymnasium.envs.external.gym_env:GymEnvironment",
)
# Hook to load plugins from entry points
load_env_plugins()

View File

View File

@@ -1,159 +0,0 @@
from typing import Optional, Tuple
import gymnasium
from gymnasium import error
from gymnasium.core import ActType, ObsType
try:
import gym
except ImportError as e:
GYM_IMPORT_ERROR = e
else:
GYM_IMPORT_ERROR = None
class GymEnvironment(gymnasium.Env):
"""
Converts a gym environment to a gymnasium environment.
"""
def __init__(
self,
env_id: Optional[str] = None,
make_kwargs: Optional[dict] = None,
env: Optional["gym.Env"] = None,
):
if GYM_IMPORT_ERROR is not None:
raise error.DependencyNotInstalled(
f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments"
)
if make_kwargs is None:
make_kwargs = {}
if env is not None:
self.gym_env = env
elif env_id is not None:
self.gym_env = gym.make(env_id, **make_kwargs)
else:
raise gymnasium.error.MissingArgument(
"Either env_id or env must be provided to create a legacy gym environment."
)
self.gym_env = _strip_default_wrappers(self.gym_env)
self.observation_space = _convert_space(self.gym_env.observation_space)
self.action_space = _convert_space(self.gym_env.action_space)
self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []})
self.render_mode = self.gym_env.render_mode
self.reward_range = getattr(self.gym_env, "reward_range", None)
self.spec = getattr(self.gym_env, "spec", None)
def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[ObsType, dict]:
"""Resets the environment.
Args:
seed: the seed to reset the environment with
options: the options to reset the environment with
Returns:
(observation, info)
"""
super().reset(seed=seed)
# Options are ignored
return self.gym_env.reset(seed=seed, options=options)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""Steps through the environment.
Args:
action: action to step through the environment with
Returns:
(observation, reward, terminated, truncated, info)
"""
return self.gym_env.step(action)
def render(self):
"""Renders the environment.
Returns:
The rendering of the environment, depending on the render mode
"""
return self.gym_env.render()
def close(self):
"""Closes the environment."""
self.gym_env.close()
def __str__(self):
return f"GymEnvironment({self.gym_env})"
def __repr__(self):
return f"GymEnvironment({self.gym_env})"
def _strip_default_wrappers(env: "gym.Env") -> "gym.Env":
"""Strips builtin wrappers from the environment.
Args:
env: the environment to strip builtin wrappers from
Returns:
The environment without builtin wrappers
"""
import gym.wrappers
default_wrappers = (
gym.wrappers.render_collection.RenderCollection,
gym.wrappers.human_rendering.HumanRendering,
)
while isinstance(env, default_wrappers):
env = env.env
return env
def _convert_space(space: "gym.Space") -> gymnasium.Space:
"""Converts a gym space to a gymnasium space.
Args:
space: the space to convert
Returns:
The converted space
"""
if isinstance(space, gym.spaces.Discrete):
return gymnasium.spaces.Discrete(n=space.n)
elif isinstance(space, gym.spaces.Box):
return gymnasium.spaces.Box(
low=space.low, high=space.high, shape=space.shape, dtype=space.dtype
)
elif isinstance(space, gym.spaces.MultiDiscrete):
return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
elif isinstance(space, gym.spaces.MultiBinary):
return gymnasium.spaces.MultiBinary(n=space.n)
elif isinstance(space, gym.spaces.Tuple):
return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces)))
elif isinstance(space, gym.spaces.Dict):
return gymnasium.spaces.Dict(
spaces={k: _convert_space(v) for k, v in space.spaces.items()}
)
elif isinstance(space, gym.spaces.Sequence):
return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space))
elif isinstance(space, gym.spaces.Graph):
return gymnasium.spaces.Graph(
node_space=_convert_space(space.node_space), # type: ignore
edge_space=_convert_space(space.edge_space), # type: ignore
)
elif isinstance(space, gym.spaces.Text):
return gymnasium.spaces.Text(
max_length=space.max_length,
min_length=space.min_length,
charset=space._char_str,
)
else:
raise NotImplementedError(
f"Cannot convert space of type {space}. Please upgrade your code to gymnasium."
)

View File

@@ -3,6 +3,7 @@ import sys
from typing import Any, Dict, Optional, Tuple
import gymnasium as gym
from gymnasium import logger
from gymnasium.core import ObsType
from gymnasium.utils.step_api_compatibility import (
convert_to_terminated_truncated_step_api,
@@ -62,6 +63,10 @@ class EnvCompatibility(gym.Env):
old_env (LegacyEnv): the env to wrap, implemented with the old API
render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render
"""
logger.warn(
"The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v28. "
"Instead use `gym.make('GymV22Environment-v0', env_name=...)` or `from shimmy import GymV26CompatibilityV0`"
)
self.metadata = getattr(old_env, "metadata", {"render_modes": []})
self.render_mode = render_mode
self.reward_range = getattr(old_env, "reward_range", None)

View File

@@ -33,7 +33,7 @@ def get_version():
# Environment-specific dependencies.
extras = {
"atari": ["ale-py~=0.8.0"],
"atari": ["shimmy[atari]>=0.1.0,<1.0"],
"accept-rom-license": ["autorom[accept-rom-license]~=0.4.2"],
"box2d": ["box2d-py==2.3.5", "pygame==2.1.0", "swig==4.*"],
"classic_control": ["pygame==2.1.0"],
@@ -46,7 +46,6 @@ extras = {
extras["testing"] = list(set(itertools.chain.from_iterable(extras.values()))) + [
"pytest==7.1.3",
"gym[classic_control, mujoco_py, mujoco, toy_text, other, atari, accept-rom-license]==0.26.2",
]
# All dependency groups - accept rom license as requires user to run
@@ -90,6 +89,7 @@ setup(
"cloudpickle >= 1.2.0",
"importlib_metadata >= 4.8.0; python_version < '3.10'",
"gymnasium_notices >= 0.0.1",
"shimmy>=0.1.0, <1.0",
],
classifiers=[
"Programming Language :: Python :: 3",

View File

@@ -1,43 +0,0 @@
import warnings
import pytest
import gymnasium
from gymnasium.utils.env_checker import check_env
from tests.envs.test_envs import CHECK_ENV_IGNORE_WARNINGS
pytest.importorskip("gym")
import gym # noqa: E402, isort: skip
# We do not test Atari environment's here because we check all variants of Pong in test_envs.py (There are too many Atari environments)
ALL_GYM_ENVS = [
env_id
for env_id, spec in gym.envs.registry.items()
if ("ale_py" not in spec.entry_point or "Pong" in env_id)
]
@pytest.mark.parametrize(
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
)
def test_gym_conversion_by_id(env_id):
env = gymnasium.make("GymV26Environment-v0", env_id=env_id).unwrapped
with warnings.catch_warnings(record=True) as caught_warnings:
check_env(env, skip_render_check=True)
for warning in caught_warnings:
if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
raise gymnasium.error.Error(f"Unexpected warning: {warning.message}")
@pytest.mark.parametrize(
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
)
def test_gym_conversion_instantiated(env_id):
env = gym.make(env_id)
env = gymnasium.make("GymV26Environment-v0", env=env).unwrapped
with warnings.catch_warnings(record=True) as caught_warnings:
check_env(env, skip_render_check=True)
for warning in caught_warnings:
if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
raise gymnasium.error.Error(f"Unexpected warning: {warning.message}")

View File

@@ -22,33 +22,45 @@ from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
from tests.testing_env import GenericTestEnv, old_step_fn
from tests.wrappers.utils import has_wrapper
gym.register(
"RegisterDuringMakeEnv-v0",
entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv",
)
gym.register(
id="test.ArgumentEnv-v0",
entry_point="tests.envs.utils_envs:ArgumentEnv",
kwargs={
"arg1": "arg1",
"arg2": "arg2",
},
)
@pytest.fixture(scope="function")
def register_make_testing_envs():
"""Registers testing envs for `gym.make`"""
gym.register(
"RegisterDuringMakeEnv-v0",
entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv",
)
gym.register(
id="test/NoHuman-v0",
entry_point="tests.envs.utils_envs:NoHuman",
)
gym.register(
id="test/NoHumanOldAPI-v0",
entry_point="tests.envs.utils_envs:NoHumanOldAPI",
)
gym.register(
id="test.ArgumentEnv-v0",
entry_point="tests.envs.utils_envs:ArgumentEnv",
kwargs={
"arg1": "arg1",
"arg2": "arg2",
},
)
gym.register(
id="test/NoHumanNoRGB-v0",
entry_point="tests.envs.utils_envs:NoHumanNoRGB",
)
gym.register(
id="test/NoHuman-v0",
entry_point="tests.envs.utils_envs:NoHuman",
)
gym.register(
id="test/NoHumanOldAPI-v0",
entry_point="tests.envs.utils_envs:NoHumanOldAPI",
)
gym.register(
id="test/NoHumanNoRGB-v0",
entry_point="tests.envs.utils_envs:NoHumanNoRGB",
)
yield
del gym.envs.registration.registry["RegisterDuringMakeEnv-v0"]
del gym.envs.registration.registry["test.ArgumentEnv-v0"]
del gym.envs.registration.registry["test/NoHuman-v0"]
del gym.envs.registration.registry["test/NoHumanOldAPI-v0"]
del gym.envs.registration.registry["test/NoHumanNoRGB-v0"]
def test_make():
@@ -70,7 +82,7 @@ def test_make_deprecated():
gym.make("Humanoid-v0", disable_env_checker=True)
def test_make_max_episode_steps():
def test_make_max_episode_steps(register_make_testing_envs):
# Default, uses the spec's
env = gym.make("CartPole-v1", disable_env_checker=True)
assert has_wrapper(env, TimeLimit)
@@ -208,7 +220,7 @@ def test_make_order_enforcing():
env.close()
def test_make_render_mode():
def test_make_render_mode(register_make_testing_envs):
env = gym.make("CartPole-v1", disable_env_checker=True)
assert env.render_mode is None
env.close()
@@ -293,7 +305,7 @@ def test_make_render_mode():
gym.make("CarRacing-v2", render="human")
def test_make_kwargs():
def test_make_kwargs(register_make_testing_envs):
env = gym.make(
"test.ArgumentEnv-v0",
arg2="override_arg2",
@@ -309,7 +321,7 @@ def test_make_kwargs():
env.close()
def test_import_module_during_make():
def test_import_module_during_make(register_make_testing_envs):
# Test custom environment which is registered at make
env = gym.make(
"tests.envs.utils:RegisterDuringMakeEnv-v0",

View File

@@ -1,8 +1,15 @@
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
# To ignore the trailing whitespaces, will need flake to ignore this file.
# flake8: noqa
reduced_registry = {
env_id: env_spec
for env_id, env_spec in gym.registry.items()
if env_spec.entry_point != "shimmy.atari_env:AtariEnv"
}
def test_pprint_custom_registry():
"""Testing a registry different from default."""

View File

@@ -8,8 +8,8 @@ import gymnasium as gym
@pytest.fixture(scope="function")
def register_testing_envs():
"""Registers testing environments."""
def register_registration_testing_envs():
"""Register testing envs for `gym.register`."""
namespace = "MyAwesomeNamespace"
versioned_name = "MyAwesomeVersionedEnv"
unversioned_name = "MyAwesomeUnversionedEnv"
@@ -105,7 +105,9 @@ def test_register_error(env_id):
("MyAwesomeNamespace/MyAwesomeVersioneEnv", "MyAwesomeVersionedEnv"),
],
)
def test_env_suggestions(register_testing_envs, env_id_input, env_id_suggested):
def test_env_suggestions(
register_registration_testing_envs, env_id_input, env_id_suggested
):
with pytest.raises(
gym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?"
):
@@ -124,7 +126,10 @@ def test_env_suggestions(register_testing_envs, env_id_input, env_id_suggested):
],
)
def test_env_version_suggestions(
register_testing_envs, env_id_input, suggested_versions, default_version
register_registration_testing_envs,
env_id_input,
suggested_versions,
default_version,
):
if default_version:
with pytest.raises(
@@ -173,7 +178,7 @@ def test_register_versioned_unversioned():
del gym.envs.registry[unversioned_env]
def test_make_latest_versioned_env(register_testing_envs):
def test_make_latest_versioned_env(register_registration_testing_envs):
with pytest.warns(
UserWarning,
match=re.escape(