Remove AtariEnv in favour of official ALE Python package (#2348)

* Remove AtariEnv in favour of official ALE Python
* More robust frame stacking test case
* Atari documentation update

committed by GitHub
parent 263a3419ef
commit f6742ea808
@@ -14,22 +14,27 @@ It's worth browsing through both.

### Atari

The Atari environments are a variety of Atari video games. If you didn't do the full install, you can install dependencies via `pip install -e '.[atari]'` (you'll need `cmake` installed) and then get started as follows:

The Atari environments are provided by the [Arcade Learning Environment](https://github.com/mgbellemare/Arcade-Learning-Environment) (ALE).
If you didn't do the full install, you can install dependencies via `pip install -e '.[atari]'`, which will install `ale-py`.
You can then create any of the legacy Atari environments as such:

```python
import gym

env = gym.make('SpaceInvaders-v4')
```

Newer versions of the Atari environments live in the [ALE](https://github.com/mgbellemare/Arcade-Learning-Environment) repository and are namespaced with `ALE`. For example, the `v5` environments can be included as such:

```python
import gym
env = gym.make('SpaceInvaders-v4')
env.reset()
env.render()
import ale_py

env = gym.make('ALE/SpaceInvaders-v5')
```
This will install `atari-py`, which automatically compiles the [Arcade Learning Environment](https://github.com/mgbellemare/Arcade-Learning-Environment). This can take quite a while (a few minutes on a decent laptop), so just be prepared.

Note: ROMs are not distributed by the ALE, but tools are provided in the ALE to manage ROMs. Please see the project's documentation on importing ROMs.

### Box2d
@@ -613,7 +613,6 @@ for reward_type in ["sparse", "dense"]:

# Atari
# ----------------------------------------

# # print ', '.join(["'{}'".format(name.split('.')[0]) for name in atari_py.list_games()])
for game in [
    "adventure",
    "air_raid",
@@ -678,7 +677,7 @@ for game in [
    "yars_revenge",
    "zaxxon",
]:
    for obs_type in ["image", "ram"]:
    for obs_type in ["rgb", "ram"]:
        # space_invaders should yield SpaceInvaders-v0 and SpaceInvaders-ram-v0
        name = "".join([g.capitalize() for g in game.split("_")])
        if obs_type == "ram":
@@ -694,7 +693,7 @@ for game in [

        register(
            id="{}-v0".format(name),
            entry_point="gym.envs.atari:AtariEnv",
            entry_point="ale_py.gym:ALGymEnv",
            kwargs={
                "game": game,
                "obs_type": obs_type,
@@ -706,7 +705,7 @@ for game in [

        register(
            id="{}-v4".format(name),
            entry_point="gym.envs.atari:AtariEnv",
            entry_point="ale_py.gym:ALGymEnv",
            kwargs={"game": game, "obs_type": obs_type},
            max_episode_steps=100000,
            nondeterministic=nondeterministic,
@@ -721,7 +720,7 @@ for game in [
        # Use a deterministic frame skip.
        register(
            id="{}Deterministic-v0".format(name),
            entry_point="gym.envs.atari:AtariEnv",
            entry_point="ale_py.gym:ALGymEnv",
            kwargs={
                "game": game,
                "obs_type": obs_type,
@@ -734,7 +733,7 @@ for game in [

        register(
            id="{}Deterministic-v4".format(name),
            entry_point="gym.envs.atari:AtariEnv",
            entry_point="ale_py.gym:ALGymEnv",
            kwargs={"game": game, "obs_type": obs_type, "frameskip": frameskip},
            max_episode_steps=100000,
            nondeterministic=nondeterministic,
@@ -742,7 +741,7 @@ for game in [

        register(
            id="{}NoFrameskip-v0".format(name),
            entry_point="gym.envs.atari:AtariEnv",
            entry_point="ale_py.gym:ALGymEnv",
            kwargs={
                "game": game,
                "obs_type": obs_type,
@@ -757,7 +756,7 @@ for game in [
        # deterministic environments.)
        register(
            id="{}NoFrameskip-v4".format(name),
            entry_point="gym.envs.atari:AtariEnv",
            entry_point="ale_py.gym:ALGymEnv",
            kwargs={
                "game": game,
                "obs_type": obs_type,
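For orientation, every hunk above funnels into gym's standard `register()` helper; a minimal sketch of the kind of call a single loop iteration produces after this change is shown below. The id, game, and kwarg values are illustrative assumptions, and the id is deliberately non-standard so it does not clash with the registrations gym itself performs.

```python
# Sketch only: mirrors the registration pattern above with the new ale-py entry point.
from gym.envs.registration import register

register(
    id="MyPongNoFrameskip-v4",           # hypothetical id, avoids clashing with gym's own registry
    entry_point="ale_py.gym:ALGymEnv",   # the entry point introduced by this commit
    kwargs={"game": "pong", "obs_type": "rgb", "frameskip": 1},
    max_episode_steps=100000,
    nondeterministic=False,
)
```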
@@ -1 +0,0 @@
from gym.envs.atari.atari_env import AtariEnv
@@ -1,253 +0,0 @@
import numpy as np
import os
import gym
from gym import error, spaces
from gym import utils
from gym.utils import seeding

try:
    import atari_py
except ImportError as e:
    raise error.DependencyNotInstalled(
        "{}. (HINT: you can install Atari dependencies by running "
        "'pip install gym[atari]'.)".format(e)
    )


def to_ram(ale):
    ram_size = ale.getRAMSize()
    ram = np.zeros((ram_size), dtype=np.uint8)
    ale.getRAM(ram)
    return ram


class AtariEnv(gym.Env, utils.EzPickle):
    metadata = {"render.modes": ["human", "rgb_array"]}

    def __init__(
        self,
        game="pong",
        mode=None,
        difficulty=None,
        obs_type="ram",
        frameskip=(2, 5),
        repeat_action_probability=0.0,
        full_action_space=False,
    ):
        """Frameskip should be either a tuple (indicating a random range to
        choose from, with the top value exclude), or an int."""

        utils.EzPickle.__init__(
            self,
            game,
            mode,
            difficulty,
            obs_type,
            frameskip,
            repeat_action_probability,
            full_action_space,
        )
        assert obs_type in ("ram", "image")

        self.game = game
        self.game_path = atari_py.get_game_path(game)
        self.game_mode = mode
        self.game_difficulty = difficulty

        if not os.path.exists(self.game_path):
            msg = "You asked for game %s but path %s does not exist"
            raise IOError(msg % (game, self.game_path))
        self._obs_type = obs_type
        self.frameskip = frameskip
        self.ale = atari_py.ALEInterface()
        self.viewer = None

        # Tune (or disable) ALE's action repeat:
        # https://github.com/openai/gym/issues/349
        assert isinstance(
            repeat_action_probability, (float, int)
        ), "Invalid repeat_action_probability: {!r}".format(repeat_action_probability)
        self.ale.setFloat(
            "repeat_action_probability".encode("utf-8"), repeat_action_probability
        )

        self.seed()

        self._action_set = (
            self.ale.getLegalActionSet()
            if full_action_space
            else self.ale.getMinimalActionSet()
        )
        self.action_space = spaces.Discrete(len(self._action_set))

        (screen_width, screen_height) = self.ale.getScreenDims()
        if self._obs_type == "ram":
            self.observation_space = spaces.Box(
                low=0, high=255, dtype=np.uint8, shape=(128,)
            )
        elif self._obs_type == "image":
            self.observation_space = spaces.Box(
                low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8
            )
        else:
            raise error.Error(
                "Unrecognized observation type: {}".format(self._obs_type)
            )

    def seed(self, seed=None):
        self.np_random, seed1 = seeding.np_random(seed)
        # Derive a random seed. This gets passed as a uint, but gets
        # checked as an int elsewhere, so we need to keep it below
        # 2**31.
        seed2 = seeding.hash_seed(seed1 + 1) % 2 ** 31
        # Empirically, we need to seed before loading the ROM.
        self.ale.setInt(b"random_seed", seed2)
        self.ale.loadROM(self.game_path)

        if self.game_mode is not None:
            modes = self.ale.getAvailableModes()

            assert self.game_mode in modes, (
                'Invalid game mode "{}" for game {}.\nAvailable modes are: {}'
            ).format(self.game_mode, self.game, modes)
            self.ale.setMode(self.game_mode)

        if self.game_difficulty is not None:
            difficulties = self.ale.getAvailableDifficulties()

            assert self.game_difficulty in difficulties, (
                'Invalid game difficulty "{}" for game {}.\nAvailable difficulties are: {}'
            ).format(self.game_difficulty, self.game, difficulties)
            self.ale.setDifficulty(self.game_difficulty)

        return [seed1, seed2]

    def step(self, a):
        reward = 0.0
        action = self._action_set[a]

        if isinstance(self.frameskip, int):
            num_steps = self.frameskip
        else:
            num_steps = self.np_random.randint(self.frameskip[0], self.frameskip[1])
        for _ in range(num_steps):
            reward += self.ale.act(action)
        ob = self._get_obs()

        return ob, reward, self.ale.game_over(), {"ale.lives": self.ale.lives()}

    def _get_image(self):
        return self.ale.getScreenRGB2()

    def _get_ram(self):
        return to_ram(self.ale)

    @property
    def _n_actions(self):
        return len(self._action_set)

    def _get_obs(self):
        if self._obs_type == "ram":
            return self._get_ram()
        elif self._obs_type == "image":
            img = self._get_image()
        return img

    # return: (states, observations)
    def reset(self):
        self.ale.reset_game()
        return self._get_obs()

    def render(self, mode="human"):
        img = self._get_image()
        if mode == "rgb_array":
            return img
        elif mode == "human":
            from gym.envs.classic_control import rendering

            if self.viewer is None:
                self.viewer = rendering.SimpleImageViewer()
            self.viewer.imshow(img)
            return self.viewer.isopen

    def close(self):
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

    def get_action_meanings(self):
        return [ACTION_MEANING[i] for i in self._action_set]

    def get_keys_to_action(self):
        KEYWORD_TO_KEY = {
            "UP": ord("w"),
            "DOWN": ord("s"),
            "LEFT": ord("a"),
            "RIGHT": ord("d"),
            "FIRE": ord(" "),
        }

        keys_to_action = {}

        for action_id, action_meaning in enumerate(self.get_action_meanings()):
            keys = []
            for keyword, key in KEYWORD_TO_KEY.items():
                if keyword in action_meaning:
                    keys.append(key)
            keys = tuple(sorted(keys))

            assert keys not in keys_to_action
            keys_to_action[keys] = action_id

        return keys_to_action

    def clone_state(self):
        """Clone emulator state w/o system state. Restoring this state will
        *not* give an identical environment. For complete cloning and restoring
        of the full state, see `{clone,restore}_full_state()`."""
        state_ref = self.ale.cloneState()
        state = self.ale.encodeState(state_ref)
        self.ale.deleteState(state_ref)
        return state

    def restore_state(self, state):
        """Restore emulator state w/o system state."""
        state_ref = self.ale.decodeState(state)
        self.ale.restoreState(state_ref)
        self.ale.deleteState(state_ref)

    def clone_full_state(self):
        """Clone emulator state w/ system state including pseudorandomness.
        Restoring this state will give an identical environment."""
        state_ref = self.ale.cloneSystemState()
        state = self.ale.encodeState(state_ref)
        self.ale.deleteState(state_ref)
        return state

    def restore_full_state(self, state):
        """Restore emulator state w/ system state including pseudorandomness."""
        state_ref = self.ale.decodeState(state)
        self.ale.restoreSystemState(state_ref)
        self.ale.deleteState(state_ref)


ACTION_MEANING = {
    0: "NOOP",
    1: "FIRE",
    2: "UP",
    3: "RIGHT",
    4: "LEFT",
    5: "DOWN",
    6: "UPRIGHT",
    7: "UPLEFT",
    8: "DOWNRIGHT",
    9: "DOWNLEFT",
    10: "UPFIRE",
    11: "RIGHTFIRE",
    12: "LEFTFIRE",
    13: "DOWNFIRE",
    14: "UPRIGHTFIRE",
    15: "UPLEFTFIRE",
    16: "DOWNRIGHTFIRE",
    17: "DOWNLEFTFIRE",
}
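With the bundled `AtariEnv` above removed, the equivalent environments now come from the env class that `ale-py` exposes under the `ALE` namespace. A minimal usage sketch, assuming `ale-py` >= 0.7 is installed and the relevant ROM has been imported, and assuming `get_action_meanings()` remains available on the replacement env (mirroring the deleted class):

```python
import gym
import ale_py  # noqa: F401  (mirrors the README example above)

env = gym.make("ALE/Pong-v5")
obs = env.reset()
print(env.action_space)                      # Discrete action space over the game's action set
print(env.unwrapped.get_action_meanings())   # assumed to map indices to names like 'NOOP', 'FIRE', ...
env.close()
```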
@@ -26,9 +26,9 @@ def should_skip_env_spec_for_tests(spec):
    ):
        return True
    try:
        import atari_py
        import ale_py
    except ImportError:
        if ep.startswith("gym.envs.atari"):
        if ep.startswith("ale_py.gym"):
            return True
    try:
        import Box2D
@@ -40,7 +40,7 @@ def should_skip_env_spec_for_tests(spec):
        "GoEnv" in ep
        or "HexEnv" in ep
        or (
            ep.startswith("gym.envs.atari")
            ep.startswith("ale_py.gym")
            and not spec.id.startswith("Pong")
            and not spec.id.startswith("Seaquest")
        )
@@ -119,12 +119,12 @@ class AtariPreprocessing(gym.Wrapper):
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[1])
                else:
                    self.ale.getScreenRGB2(self.obs_buffer[1])
                    self.ale.getScreenRGB(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[0])
                else:
                    self.ale.getScreenRGB2(self.obs_buffer[0])
                    self.ale.getScreenRGB(self.obs_buffer[0])
        return self._get_obs(), R, done, info

    def reset(self, **kwargs):
@@ -144,7 +144,7 @@ class AtariPreprocessing(gym.Wrapper):
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(self.obs_buffer[0])
        else:
            self.ale.getScreenRGB2(self.obs_buffer[0])
            self.ale.getScreenRGB(self.obs_buffer[0])
        self.obs_buffer[1].fill(0)
        return self._get_obs()
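As a quick orientation for the wrapper touched here, a usage sketch of `gym.wrappers.AtariPreprocessing` (the env id and parameter values are illustrative; the wrapper expects an env without built-in frame skipping and requires `opencv-python`):

```python
import gym
from gym.wrappers import AtariPreprocessing

# Use a NoFrameskip variant: the wrapper applies its own frame_skip and
# grabs frames from the underlying ALE (getScreenGrayscale / getScreenRGB above).
env = gym.make("SpaceInvadersNoFrameskip-v4")
env = AtariPreprocessing(env, frame_skip=4, grayscale_obs=True, screen_size=84)

obs = env.reset()
print(obs.shape)  # (84, 84) when grayscale_obs=True
```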
@@ -3,7 +3,7 @@ import gym
from gym.wrappers import AtariPreprocessing
import pytest

pytest.importorskip("atari_py")
pytest.importorskip("ale_py")


@pytest.fixture(scope="module")
@@ -1,6 +1,6 @@
import pytest

pytest.importorskip("atari_py")
pytest.importorskip("ale_py")

import numpy as np
import gym
@@ -28,23 +28,23 @@ except ImportError:
)
def test_frame_stack(env_id, num_stack, lz4_compress):
    env = gym.make(env_id)
    env.seed(0)
    shape = env.observation_space.shape
    env = FrameStack(env, num_stack, lz4_compress)
    assert env.observation_space.shape == (num_stack,) + shape
    assert env.observation_space.dtype == env.env.observation_space.dtype

    dup = gym.make(env_id)
    dup.seed(0)

    obs = env.reset()
    obs = np.asarray(obs)
    assert obs.shape == (num_stack,) + shape
    for i in range(1, num_stack):
        assert np.allclose(obs[i - 1], obs[i])
    dup_obs = dup.reset()
    assert np.allclose(obs[-1], dup_obs)

    obs, _, _, _ = env.step(env.action_space.sample())
    obs = np.asarray(obs)
    assert obs.shape == (num_stack,) + shape
    for i in range(1, num_stack - 1):
        assert np.allclose(obs[i - 1], obs[i])
    assert not np.allclose(obs[-1], obs[-2])
    for _ in range(num_stack ** 2):
        action = env.action_space.sample()
        dup_obs, _, _, _ = dup.step(action)
        obs, _, _, _ = env.step(action)
        assert np.allclose(obs[-1], dup_obs)

    obs, _, _, _ = env.step(env.action_space.sample())
    assert len(obs) == num_stack
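For reference, a short usage sketch of the `FrameStack` wrapper this test exercises (CartPole-v1 is used purely as a lightweight stand-in env; the shapes follow the assertions above):

```python
import numpy as np
import gym
from gym.wrappers import FrameStack

env = gym.make("CartPole-v1")
base_shape = env.observation_space.shape      # (4,) for CartPole
env = FrameStack(env, num_stack=4)

obs = env.reset()                             # LazyFrames holding the last 4 observations
assert np.asarray(obs).shape == (4,) + base_shape
```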
@@ -6,7 +6,7 @@ import gym
from gym.wrappers import GrayScaleObservation
from gym.wrappers import AtariPreprocessing

pytest.importorskip("atari_py")
pytest.importorskip("ale_py")
pytest.importorskip("cv2")
@@ -4,13 +4,13 @@ import gym
from gym.wrappers import ResizeObservation

try:
    import atari_py
    import ale_py
except ImportError:
    atari_py = None
    ale_py = None


@pytest.mark.skipif(
    atari_py is None, reason="Only run this test when atari_py is installed"
    ale_py is None, reason="Only run this test when ale_py is installed"
)
@pytest.mark.parametrize(
    "env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
@@ -16,5 +16,6 @@ COPY . /usr/local/gym/
WORKDIR /usr/local/gym/

RUN pip install -e .[nomujoco] && pip install -r test_requirements.txt
RUN AutoROM --accept-license

ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]
@@ -1,4 +1,4 @@
atari-py==0.2.6
ale-py~=0.7
opencv-python>=3.
box2d-py==2.3.5
mujoco_py>=1.50, <2.0
setup.py
@@ -9,7 +9,7 @@ from version import VERSION

# Environment-specific dependencies.
extras = {
    "atari": ["atari-py==0.2.6", "opencv-python>=3."],
    "atari": ["ale-py~=0.7"],
    "box2d": ["box2d-py==2.3.5", "pyglet>=1.4.0"],
    "classic_control": ["pyglet>=1.4.0"],
    "mujoco": ["mujoco_py>=1.50, <2.0"],
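For context, an extras mapping like the one above is typically wired into setuptools roughly as follows (a sketch; the surrounding `setup()` arguments are illustrative, and only the `atari` entry is taken from this diff):

```python
from setuptools import setup, find_packages

extras = {
    "atari": ["ale-py~=0.7"],   # entry taken from this diff; replaces atari-py + opencv-python
}

setup(
    name="gym",
    packages=find_packages(),
    extras_require=extras,      # enables e.g. `pip install gym[atari]`
)
```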
@@ -1,3 +1,4 @@
lz4~=3.1
pytest~=6.2
pytest-forked~=1.3
AutoROM>=0.3