mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Remove AtariEnv in favour of official ALE Python package (#2348)
* Remove AtariEnv in favour of official ALE Python package
* More robust frame stacking test case
* Atari documentation update
This commit is contained in:
committed by
GitHub
parent
263a3419ef
commit
f6742ea808
@@ -14,22 +14,27 @@ It's worth browsing through both.
|
|||||||
|
|
||||||
### Atari
|
### Atari
|
||||||
|
|
||||||
The Atari environments are a variety of Atari video games. If you didn't
|
The Atari environments are provided by the [Arcade Learning Environment](https://github.com/mgbellemare/Arcade-Learning-Environment) (ALE).
|
||||||
do the full install, you can install dependencies via `pip install -e
|
If you didn't do the full install, you can install dependencies via `pip install -e '.[atari]'` which will install `ale-py`.
|
||||||
'.[atari]'` (you'll need `cmake` installed) and then get started as
|
You can then create any of the legacy Atari environments as follows:
|
||||||
follows:
|
|
||||||
|
|
||||||
``` python
|
``` python
|
||||||
import gym
|
import gym
|
||||||
|
|
||||||
env = gym.make('SpaceInvaders-v4')
|
env = gym.make('SpaceInvaders-v4')
|
||||||
env.reset()
|
|
||||||
env.render()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
This will install `atari-py`, which automatically compiles the [Arcade
|
Newer versions of the Atari environments live in the [ALE](https://github.com/mgbellemare/Arcade-Learning-Environment) repository
|
||||||
Learning Environment](https://github.com/mgbellemare/Arcade-Learning-Environment#:~:text=The%20Arcade%20Learning%20Environment%20(ALE)%20is%20a%20simple%20object%2D,of%20emulation%20from%20agent%20design.). This
|
and are namespaced with `ALE`. For example, the `v5` environments can be created as follows:
|
||||||
can take quite a while (a few minutes on a decent laptop), so just be
|
|
||||||
prepared.
|
```python
|
||||||
|
import gym
|
||||||
|
import ale_py
|
||||||
|
|
||||||
|
env = gym.make('ALE/SpaceInvaders-v5')
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: The ALE does not distribute ROMs, but it provides tools to manage them.
|
||||||
|
Please see the project's documentation on importing ROMs.
|
||||||
|
|
||||||
### Box2d
|
### Box2d
|
||||||
|
|
||||||
|
@@ -613,7 +613,6 @@ for reward_type in ["sparse", "dense"]:
|
|||||||
# Atari
|
# Atari
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
|
|
||||||
# # print ', '.join(["'{}'".format(name.split('.')[0]) for name in atari_py.list_games()])
|
|
||||||
for game in [
|
for game in [
|
||||||
"adventure",
|
"adventure",
|
||||||
"air_raid",
|
"air_raid",
|
||||||
@@ -678,7 +677,7 @@ for game in [
|
|||||||
"yars_revenge",
|
"yars_revenge",
|
||||||
"zaxxon",
|
"zaxxon",
|
||||||
]:
|
]:
|
||||||
for obs_type in ["image", "ram"]:
|
for obs_type in ["rgb", "ram"]:
|
||||||
# space_invaders should yield SpaceInvaders-v0 and SpaceInvaders-ram-v0
|
# space_invaders should yield SpaceInvaders-v0 and SpaceInvaders-ram-v0
|
||||||
name = "".join([g.capitalize() for g in game.split("_")])
|
name = "".join([g.capitalize() for g in game.split("_")])
|
||||||
if obs_type == "ram":
|
if obs_type == "ram":
|
||||||
@@ -694,7 +693,7 @@ for game in [
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id="{}-v0".format(name),
|
id="{}-v0".format(name),
|
||||||
entry_point="gym.envs.atari:AtariEnv",
|
entry_point="ale_py.gym:ALGymEnv",
|
||||||
kwargs={
|
kwargs={
|
||||||
"game": game,
|
"game": game,
|
||||||
"obs_type": obs_type,
|
"obs_type": obs_type,
|
||||||
@@ -706,7 +705,7 @@ for game in [
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id="{}-v4".format(name),
|
id="{}-v4".format(name),
|
||||||
entry_point="gym.envs.atari:AtariEnv",
|
entry_point="ale_py.gym:ALGymEnv",
|
||||||
kwargs={"game": game, "obs_type": obs_type},
|
kwargs={"game": game, "obs_type": obs_type},
|
||||||
max_episode_steps=100000,
|
max_episode_steps=100000,
|
||||||
nondeterministic=nondeterministic,
|
nondeterministic=nondeterministic,
|
||||||
@@ -721,7 +720,7 @@ for game in [
|
|||||||
# Use a deterministic frame skip.
|
# Use a deterministic frame skip.
|
||||||
register(
|
register(
|
||||||
id="{}Deterministic-v0".format(name),
|
id="{}Deterministic-v0".format(name),
|
||||||
entry_point="gym.envs.atari:AtariEnv",
|
entry_point="ale_py.gym:ALGymEnv",
|
||||||
kwargs={
|
kwargs={
|
||||||
"game": game,
|
"game": game,
|
||||||
"obs_type": obs_type,
|
"obs_type": obs_type,
|
||||||
@@ -734,7 +733,7 @@ for game in [
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id="{}Deterministic-v4".format(name),
|
id="{}Deterministic-v4".format(name),
|
||||||
entry_point="gym.envs.atari:AtariEnv",
|
entry_point="ale_py.gym:ALGymEnv",
|
||||||
kwargs={"game": game, "obs_type": obs_type, "frameskip": frameskip},
|
kwargs={"game": game, "obs_type": obs_type, "frameskip": frameskip},
|
||||||
max_episode_steps=100000,
|
max_episode_steps=100000,
|
||||||
nondeterministic=nondeterministic,
|
nondeterministic=nondeterministic,
|
||||||
@@ -742,7 +741,7 @@ for game in [
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id="{}NoFrameskip-v0".format(name),
|
id="{}NoFrameskip-v0".format(name),
|
||||||
entry_point="gym.envs.atari:AtariEnv",
|
entry_point="ale_py.gym:ALGymEnv",
|
||||||
kwargs={
|
kwargs={
|
||||||
"game": game,
|
"game": game,
|
||||||
"obs_type": obs_type,
|
"obs_type": obs_type,
|
||||||
@@ -757,7 +756,7 @@ for game in [
|
|||||||
# deterministic environments.)
|
# deterministic environments.)
|
||||||
register(
|
register(
|
||||||
id="{}NoFrameskip-v4".format(name),
|
id="{}NoFrameskip-v4".format(name),
|
||||||
entry_point="gym.envs.atari:AtariEnv",
|
entry_point="ale_py.gym:ALGymEnv",
|
||||||
kwargs={
|
kwargs={
|
||||||
"game": game,
|
"game": game,
|
||||||
"obs_type": obs_type,
|
"obs_type": obs_type,
|
||||||
|
@@ -1 +0,0 @@
|
|||||||
from gym.envs.atari.atari_env import AtariEnv
|
|
@@ -1,253 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import gym
|
|
||||||
from gym import error, spaces
|
|
||||||
from gym import utils
|
|
||||||
from gym.utils import seeding
|
|
||||||
|
|
||||||
try:
|
|
||||||
import atari_py
|
|
||||||
except ImportError as e:
|
|
||||||
raise error.DependencyNotInstalled(
|
|
||||||
"{}. (HINT: you can install Atari dependencies by running "
|
|
||||||
"'pip install gym[atari]'.)".format(e)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def to_ram(ale):
|
|
||||||
ram_size = ale.getRAMSize()
|
|
||||||
ram = np.zeros((ram_size), dtype=np.uint8)
|
|
||||||
ale.getRAM(ram)
|
|
||||||
return ram
|
|
||||||
|
|
||||||
|
|
||||||
class AtariEnv(gym.Env, utils.EzPickle):
|
|
||||||
metadata = {"render.modes": ["human", "rgb_array"]}
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
game="pong",
|
|
||||||
mode=None,
|
|
||||||
difficulty=None,
|
|
||||||
obs_type="ram",
|
|
||||||
frameskip=(2, 5),
|
|
||||||
repeat_action_probability=0.0,
|
|
||||||
full_action_space=False,
|
|
||||||
):
|
|
||||||
"""Frameskip should be either a tuple (indicating a random range to
|
|
||||||
choose from, with the top value excluded), or an int."""
|
|
||||||
|
|
||||||
utils.EzPickle.__init__(
|
|
||||||
self,
|
|
||||||
game,
|
|
||||||
mode,
|
|
||||||
difficulty,
|
|
||||||
obs_type,
|
|
||||||
frameskip,
|
|
||||||
repeat_action_probability,
|
|
||||||
full_action_space,
|
|
||||||
)
|
|
||||||
assert obs_type in ("ram", "image")
|
|
||||||
|
|
||||||
self.game = game
|
|
||||||
self.game_path = atari_py.get_game_path(game)
|
|
||||||
self.game_mode = mode
|
|
||||||
self.game_difficulty = difficulty
|
|
||||||
|
|
||||||
if not os.path.exists(self.game_path):
|
|
||||||
msg = "You asked for game %s but path %s does not exist"
|
|
||||||
raise IOError(msg % (game, self.game_path))
|
|
||||||
self._obs_type = obs_type
|
|
||||||
self.frameskip = frameskip
|
|
||||||
self.ale = atari_py.ALEInterface()
|
|
||||||
self.viewer = None
|
|
||||||
|
|
||||||
# Tune (or disable) ALE's action repeat:
|
|
||||||
# https://github.com/openai/gym/issues/349
|
|
||||||
assert isinstance(
|
|
||||||
repeat_action_probability, (float, int)
|
|
||||||
), "Invalid repeat_action_probability: {!r}".format(repeat_action_probability)
|
|
||||||
self.ale.setFloat(
|
|
||||||
"repeat_action_probability".encode("utf-8"), repeat_action_probability
|
|
||||||
)
|
|
||||||
|
|
||||||
self.seed()
|
|
||||||
|
|
||||||
self._action_set = (
|
|
||||||
self.ale.getLegalActionSet()
|
|
||||||
if full_action_space
|
|
||||||
else self.ale.getMinimalActionSet()
|
|
||||||
)
|
|
||||||
self.action_space = spaces.Discrete(len(self._action_set))
|
|
||||||
|
|
||||||
(screen_width, screen_height) = self.ale.getScreenDims()
|
|
||||||
if self._obs_type == "ram":
|
|
||||||
self.observation_space = spaces.Box(
|
|
||||||
low=0, high=255, dtype=np.uint8, shape=(128,)
|
|
||||||
)
|
|
||||||
elif self._obs_type == "image":
|
|
||||||
self.observation_space = spaces.Box(
|
|
||||||
low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise error.Error(
|
|
||||||
"Unrecognized observation type: {}".format(self._obs_type)
|
|
||||||
)
|
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
self.np_random, seed1 = seeding.np_random(seed)
|
|
||||||
# Derive a random seed. This gets passed as a uint, but gets
|
|
||||||
# checked as an int elsewhere, so we need to keep it below
|
|
||||||
# 2**31.
|
|
||||||
seed2 = seeding.hash_seed(seed1 + 1) % 2 ** 31
|
|
||||||
# Empirically, we need to seed before loading the ROM.
|
|
||||||
self.ale.setInt(b"random_seed", seed2)
|
|
||||||
self.ale.loadROM(self.game_path)
|
|
||||||
|
|
||||||
if self.game_mode is not None:
|
|
||||||
modes = self.ale.getAvailableModes()
|
|
||||||
|
|
||||||
assert self.game_mode in modes, (
|
|
||||||
'Invalid game mode "{}" for game {}.\nAvailable modes are: {}'
|
|
||||||
).format(self.game_mode, self.game, modes)
|
|
||||||
self.ale.setMode(self.game_mode)
|
|
||||||
|
|
||||||
if self.game_difficulty is not None:
|
|
||||||
difficulties = self.ale.getAvailableDifficulties()
|
|
||||||
|
|
||||||
assert self.game_difficulty in difficulties, (
|
|
||||||
'Invalid game difficulty "{}" for game {}.\nAvailable difficulties are: {}'
|
|
||||||
).format(self.game_difficulty, self.game, difficulties)
|
|
||||||
self.ale.setDifficulty(self.game_difficulty)
|
|
||||||
|
|
||||||
return [seed1, seed2]
|
|
||||||
|
|
||||||
def step(self, a):
|
|
||||||
reward = 0.0
|
|
||||||
action = self._action_set[a]
|
|
||||||
|
|
||||||
if isinstance(self.frameskip, int):
|
|
||||||
num_steps = self.frameskip
|
|
||||||
else:
|
|
||||||
num_steps = self.np_random.randint(self.frameskip[0], self.frameskip[1])
|
|
||||||
for _ in range(num_steps):
|
|
||||||
reward += self.ale.act(action)
|
|
||||||
ob = self._get_obs()
|
|
||||||
|
|
||||||
return ob, reward, self.ale.game_over(), {"ale.lives": self.ale.lives()}
|
|
||||||
|
|
||||||
def _get_image(self):
|
|
||||||
return self.ale.getScreenRGB2()
|
|
||||||
|
|
||||||
def _get_ram(self):
|
|
||||||
return to_ram(self.ale)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _n_actions(self):
|
|
||||||
return len(self._action_set)
|
|
||||||
|
|
||||||
def _get_obs(self):
|
|
||||||
if self._obs_type == "ram":
|
|
||||||
return self._get_ram()
|
|
||||||
elif self._obs_type == "image":
|
|
||||||
img = self._get_image()
|
|
||||||
return img
|
|
||||||
|
|
||||||
# return: the initial observation after resetting the emulator
|
|
||||||
def reset(self):
|
|
||||||
self.ale.reset_game()
|
|
||||||
return self._get_obs()
|
|
||||||
|
|
||||||
def render(self, mode="human"):
|
|
||||||
img = self._get_image()
|
|
||||||
if mode == "rgb_array":
|
|
||||||
return img
|
|
||||||
elif mode == "human":
|
|
||||||
from gym.envs.classic_control import rendering
|
|
||||||
|
|
||||||
if self.viewer is None:
|
|
||||||
self.viewer = rendering.SimpleImageViewer()
|
|
||||||
self.viewer.imshow(img)
|
|
||||||
return self.viewer.isopen
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
if self.viewer is not None:
|
|
||||||
self.viewer.close()
|
|
||||||
self.viewer = None
|
|
||||||
|
|
||||||
def get_action_meanings(self):
|
|
||||||
return [ACTION_MEANING[i] for i in self._action_set]
|
|
||||||
|
|
||||||
def get_keys_to_action(self):
|
|
||||||
KEYWORD_TO_KEY = {
|
|
||||||
"UP": ord("w"),
|
|
||||||
"DOWN": ord("s"),
|
|
||||||
"LEFT": ord("a"),
|
|
||||||
"RIGHT": ord("d"),
|
|
||||||
"FIRE": ord(" "),
|
|
||||||
}
|
|
||||||
|
|
||||||
keys_to_action = {}
|
|
||||||
|
|
||||||
for action_id, action_meaning in enumerate(self.get_action_meanings()):
|
|
||||||
keys = []
|
|
||||||
for keyword, key in KEYWORD_TO_KEY.items():
|
|
||||||
if keyword in action_meaning:
|
|
||||||
keys.append(key)
|
|
||||||
keys = tuple(sorted(keys))
|
|
||||||
|
|
||||||
assert keys not in keys_to_action
|
|
||||||
keys_to_action[keys] = action_id
|
|
||||||
|
|
||||||
return keys_to_action
|
|
||||||
|
|
||||||
def clone_state(self):
|
|
||||||
"""Clone emulator state w/o system state. Restoring this state will
|
|
||||||
*not* give an identical environment. For complete cloning and restoring
|
|
||||||
of the full state, see `{clone,restore}_full_state()`."""
|
|
||||||
state_ref = self.ale.cloneState()
|
|
||||||
state = self.ale.encodeState(state_ref)
|
|
||||||
self.ale.deleteState(state_ref)
|
|
||||||
return state
|
|
||||||
|
|
||||||
def restore_state(self, state):
|
|
||||||
"""Restore emulator state w/o system state."""
|
|
||||||
state_ref = self.ale.decodeState(state)
|
|
||||||
self.ale.restoreState(state_ref)
|
|
||||||
self.ale.deleteState(state_ref)
|
|
||||||
|
|
||||||
def clone_full_state(self):
|
|
||||||
"""Clone emulator state w/ system state including pseudorandomness.
|
|
||||||
Restoring this state will give an identical environment."""
|
|
||||||
state_ref = self.ale.cloneSystemState()
|
|
||||||
state = self.ale.encodeState(state_ref)
|
|
||||||
self.ale.deleteState(state_ref)
|
|
||||||
return state
|
|
||||||
|
|
||||||
def restore_full_state(self, state):
|
|
||||||
"""Restore emulator state w/ system state including pseudorandomness."""
|
|
||||||
state_ref = self.ale.decodeState(state)
|
|
||||||
self.ale.restoreSystemState(state_ref)
|
|
||||||
self.ale.deleteState(state_ref)
|
|
||||||
|
|
||||||
|
|
||||||
ACTION_MEANING = {
|
|
||||||
0: "NOOP",
|
|
||||||
1: "FIRE",
|
|
||||||
2: "UP",
|
|
||||||
3: "RIGHT",
|
|
||||||
4: "LEFT",
|
|
||||||
5: "DOWN",
|
|
||||||
6: "UPRIGHT",
|
|
||||||
7: "UPLEFT",
|
|
||||||
8: "DOWNRIGHT",
|
|
||||||
9: "DOWNLEFT",
|
|
||||||
10: "UPFIRE",
|
|
||||||
11: "RIGHTFIRE",
|
|
||||||
12: "LEFTFIRE",
|
|
||||||
13: "DOWNFIRE",
|
|
||||||
14: "UPRIGHTFIRE",
|
|
||||||
15: "UPLEFTFIRE",
|
|
||||||
16: "DOWNRIGHTFIRE",
|
|
||||||
17: "DOWNLEFTFIRE",
|
|
||||||
}
|
|
@@ -26,9 +26,9 @@ def should_skip_env_spec_for_tests(spec):
|
|||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
import atari_py
|
import ale_py
|
||||||
except ImportError:
|
except ImportError:
|
||||||
if ep.startswith("gym.envs.atari"):
|
if ep.startswith("ale_py.gym"):
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
import Box2D
|
import Box2D
|
||||||
@@ -40,7 +40,7 @@ def should_skip_env_spec_for_tests(spec):
|
|||||||
"GoEnv" in ep
|
"GoEnv" in ep
|
||||||
or "HexEnv" in ep
|
or "HexEnv" in ep
|
||||||
or (
|
or (
|
||||||
ep.startswith("gym.envs.atari")
|
ep.startswith("ale_py.gym")
|
||||||
and not spec.id.startswith("Pong")
|
and not spec.id.startswith("Pong")
|
||||||
and not spec.id.startswith("Seaquest")
|
and not spec.id.startswith("Seaquest")
|
||||||
)
|
)
|
||||||
|
@@ -119,12 +119,12 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
if self.grayscale_obs:
|
if self.grayscale_obs:
|
||||||
self.ale.getScreenGrayscale(self.obs_buffer[1])
|
self.ale.getScreenGrayscale(self.obs_buffer[1])
|
||||||
else:
|
else:
|
||||||
self.ale.getScreenRGB2(self.obs_buffer[1])
|
self.ale.getScreenRGB(self.obs_buffer[1])
|
||||||
elif t == self.frame_skip - 1:
|
elif t == self.frame_skip - 1:
|
||||||
if self.grayscale_obs:
|
if self.grayscale_obs:
|
||||||
self.ale.getScreenGrayscale(self.obs_buffer[0])
|
self.ale.getScreenGrayscale(self.obs_buffer[0])
|
||||||
else:
|
else:
|
||||||
self.ale.getScreenRGB2(self.obs_buffer[0])
|
self.ale.getScreenRGB(self.obs_buffer[0])
|
||||||
return self._get_obs(), R, done, info
|
return self._get_obs(), R, done, info
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
@@ -144,7 +144,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
if self.grayscale_obs:
|
if self.grayscale_obs:
|
||||||
self.ale.getScreenGrayscale(self.obs_buffer[0])
|
self.ale.getScreenGrayscale(self.obs_buffer[0])
|
||||||
else:
|
else:
|
||||||
self.ale.getScreenRGB2(self.obs_buffer[0])
|
self.ale.getScreenRGB(self.obs_buffer[0])
|
||||||
self.obs_buffer[1].fill(0)
|
self.obs_buffer[1].fill(0)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
|
@@ -3,7 +3,7 @@ import gym
|
|||||||
from gym.wrappers import AtariPreprocessing
|
from gym.wrappers import AtariPreprocessing
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
pytest.importorskip("atari_py")
|
pytest.importorskip("ale_py")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
pytest.importorskip("atari_py")
|
pytest.importorskip("ale_py")
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
@@ -28,23 +28,23 @@ except ImportError:
|
|||||||
)
|
)
|
||||||
def test_frame_stack(env_id, num_stack, lz4_compress):
|
def test_frame_stack(env_id, num_stack, lz4_compress):
|
||||||
env = gym.make(env_id)
|
env = gym.make(env_id)
|
||||||
|
env.seed(0)
|
||||||
shape = env.observation_space.shape
|
shape = env.observation_space.shape
|
||||||
env = FrameStack(env, num_stack, lz4_compress)
|
env = FrameStack(env, num_stack, lz4_compress)
|
||||||
assert env.observation_space.shape == (num_stack,) + shape
|
assert env.observation_space.shape == (num_stack,) + shape
|
||||||
assert env.observation_space.dtype == env.env.observation_space.dtype
|
assert env.observation_space.dtype == env.env.observation_space.dtype
|
||||||
|
|
||||||
|
dup = gym.make(env_id)
|
||||||
|
dup.seed(0)
|
||||||
|
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
obs = np.asarray(obs)
|
dup_obs = dup.reset()
|
||||||
assert obs.shape == (num_stack,) + shape
|
assert np.allclose(obs[-1], dup_obs)
|
||||||
for i in range(1, num_stack):
|
|
||||||
assert np.allclose(obs[i - 1], obs[i])
|
|
||||||
|
|
||||||
obs, _, _, _ = env.step(env.action_space.sample())
|
for _ in range(num_stack ** 2):
|
||||||
obs = np.asarray(obs)
|
action = env.action_space.sample()
|
||||||
assert obs.shape == (num_stack,) + shape
|
dup_obs, _, _, _ = dup.step(action)
|
||||||
for i in range(1, num_stack - 1):
|
obs, _, _, _ = env.step(action)
|
||||||
assert np.allclose(obs[i - 1], obs[i])
|
assert np.allclose(obs[-1], dup_obs)
|
||||||
assert not np.allclose(obs[-1], obs[-2])
|
|
||||||
|
|
||||||
obs, _, _, _ = env.step(env.action_space.sample())
|
|
||||||
assert len(obs) == num_stack
|
assert len(obs) == num_stack
|
||||||
|
@@ -6,7 +6,7 @@ import gym
|
|||||||
from gym.wrappers import GrayScaleObservation
|
from gym.wrappers import GrayScaleObservation
|
||||||
from gym.wrappers import AtariPreprocessing
|
from gym.wrappers import AtariPreprocessing
|
||||||
|
|
||||||
pytest.importorskip("atari_py")
|
pytest.importorskip("ale_py")
|
||||||
pytest.importorskip("cv2")
|
pytest.importorskip("cv2")
|
||||||
|
|
||||||
|
|
||||||
|
@@ -4,13 +4,13 @@ import gym
|
|||||||
from gym.wrappers import ResizeObservation
|
from gym.wrappers import ResizeObservation
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import atari_py
|
import ale_py
|
||||||
except ImportError:
|
except ImportError:
|
||||||
atari_py = None
|
ale_py = None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
atari_py is None, reason="Only run this test when atari_py is installed"
|
ale_py is None, reason="Only run this test when ale_py is installed"
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
|
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
|
||||||
|
@@ -16,5 +16,6 @@ COPY . /usr/local/gym/
|
|||||||
WORKDIR /usr/local/gym/
|
WORKDIR /usr/local/gym/
|
||||||
|
|
||||||
RUN pip install -e .[nomujoco] && pip install -r test_requirements.txt
|
RUN pip install -e .[nomujoco] && pip install -r test_requirements.txt
|
||||||
|
RUN AutoROM --accept-license
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]
|
ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
atari-py==0.2.6
|
ale-py~=0.7
|
||||||
opencv-python>=3.
|
opencv-python>=3.
|
||||||
box2d-py==2.3.5
|
box2d-py==2.3.5
|
||||||
mujoco_py>=1.50, <2.0
|
mujoco_py>=1.50, <2.0
|
||||||
|
2
setup.py
2
setup.py
@@ -9,7 +9,7 @@ from version import VERSION
|
|||||||
|
|
||||||
# Environment-specific dependencies.
|
# Environment-specific dependencies.
|
||||||
extras = {
|
extras = {
|
||||||
"atari": ["atari-py==0.2.6", "opencv-python>=3."],
|
"atari": ["ale-py~=0.7"],
|
||||||
"box2d": ["box2d-py==2.3.5", "pyglet>=1.4.0"],
|
"box2d": ["box2d-py==2.3.5", "pyglet>=1.4.0"],
|
||||||
"classic_control": ["pyglet>=1.4.0"],
|
"classic_control": ["pyglet>=1.4.0"],
|
||||||
"mujoco": ["mujoco_py>=1.50, <2.0"],
|
"mujoco": ["mujoco_py>=1.50, <2.0"],
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
lz4~=3.1
|
lz4~=3.1
|
||||||
pytest~=6.2
|
pytest~=6.2
|
||||||
pytest-forked~=1.3
|
pytest-forked~=1.3
|
||||||
|
AutoROM>=0.3
|
||||||
|
Reference in New Issue
Block a user