Files
Gymnasium/gym/wrappers/atari_preprocessing.py

185 lines
6.5 KiB
Python
Raw Normal View History

import numpy as np
import gym
from gym.spaces import Box

try:
    import cv2
except ImportError:
    # OpenCV is optional at import time; AtariPreprocessing asserts that it
    # is available when the wrapper is constructed.
    cv2 = None
class AtariPreprocessing(gym.Wrapper):
    r"""Atari 2600 preprocessing wrapper.

    This class follows the guidelines in
    Machado et al. (2018), "Revisiting the Arcade Learning Environment:
    Evaluation Protocols and Open Problems for General Agents".

    Specifically:

    * NoopReset: obtain the initial state by taking a random number of no-ops on reset.
    * Frame skipping: 4 by default.
    * Max-pooling: over the most recent two observations.
    * Termination signal when a life is lost: turned off by default. Not recommended by Machado et al. (2018).
    * Resize to a square image: 84x84 by default.
    * Grayscale observation: optional, enabled by default.
    * Scale observation: optional, disabled by default.

    Args:
        env (Env): environment
        noop_max (int): max number of no-ops
        frame_skip (int): the frequency at which the agent experiences the game.
        screen_size (int): resize Atari frame
        terminal_on_life_loss (bool): if True, then step() returns done=True whenever a
            life is lost.
        grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
            is returned.
        grayscale_newaxis (bool): if True and grayscale_obs=True, then a channel axis is added to
            grayscale observations to make them 3-dimensional.
        scale_obs (bool): if True, then observation normalized in range [0,1] is returned. It also limits memory
            optimization benefits of FrameStack Wrapper.

    Raises:
        AssertionError: if OpenCV is not installed, if ``frame_skip`` or
            ``screen_size`` is not positive, if ``noop_max`` is negative, or
            if action 0 of the wrapped env is not ``"NOOP"``.
        ValueError: if ``frame_skip > 1`` and the wrapped env neither has
            ``"NoFrameskip"`` in its spec id nor ``_frameskip == 1``.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = False,
        grayscale_obs: bool = True,
        grayscale_newaxis: bool = False,
        scale_obs: bool = False,
    ):
        super().__init__(env)
        assert (
            cv2 is not None
        ), "opencv-python package not installed! Try running pip install gym[other] to get dependencies for atari"
        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0
        if frame_skip > 1:
            # An env without a spec yields an empty id here instead of an
            # AttributeError; the `_frameskip` attribute then decides.
            spec_id = getattr(getattr(env, "spec", None), "id", None) or ""
            if (
                "NoFrameskip" not in spec_id
                and getattr(env.unwrapped, "_frameskip", None) != 1
            ):
                raise ValueError(
                    "Disable frame-skipping in the original env. Otherwise, more than one"
                    " frame-skip will happen as through this wrapper"
                )
        self.noop_max = noop_max
        # reset() sends action 0 as the no-op, so it must really be NOOP.
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.terminal_on_life_loss = terminal_on_life_loss
        self.grayscale_obs = grayscale_obs
        self.grayscale_newaxis = grayscale_newaxis
        self.scale_obs = scale_obs

        # buffer of most recent two observations for max pooling
        # (grayscale frames drop the channel axis of the raw observation shape)
        raw_shape = env.observation_space.shape
        frame_shape = raw_shape[:2] if grayscale_obs else raw_shape
        self.obs_buffer = [
            np.empty(frame_shape, dtype=np.uint8),
            np.empty(frame_shape, dtype=np.uint8),
        ]

        self.ale = env.unwrapped.ale
        self.lives = 0
        self.game_over = False

        _low, _high, _obs_dtype = (
            (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
        )
        _shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
        if grayscale_obs and not grayscale_newaxis:
            _shape = _shape[:-1]  # Remove channel axis
        self.observation_space = Box(
            low=_low, high=_high, shape=_shape, dtype=_obs_dtype
        )

    def _capture_screen(self, buffer):
        """Ask ALE to write the current screen into ``buffer`` (grayscale or RGB)."""
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(buffer)
        else:
            self.ale.getScreenRGB(buffer)

    def _reset_env(self, **kwargs):
        """Reset the wrapped env and return its info dict.

        The info dict is the one returned by the wrapped env when
        ``return_info=True`` is passed in ``kwargs``; otherwise it is a new
        empty dict.
        """
        if kwargs.get("return_info", False):
            _, reset_info = self.env.reset(**kwargs)
            return reset_info
        self.env.reset(**kwargs)
        return {}

    def step(self, action):
        """Repeat ``action`` for up to ``frame_skip`` frames.

        Rewards are summed over the repeated frames. The loop stops early
        when the wrapped env reports done or, if ``terminal_on_life_loss`` is
        set, when the life counter decreases.

        Returns:
            Tuple ``(obs, total_reward, done, info)`` where ``obs`` comes from
            ``_get_obs()`` and ``info`` is the one from the last inner step.
        """
        total_reward = 0.0
        for t in range(self.frame_skip):
            _, reward, done, info = self.env.step(action)
            total_reward += reward
            # game_over tracks the env's own done flag, not life-loss termination.
            self.game_over = done

            if self.terminal_on_life_loss:
                new_lives = self.ale.lives()
                done = done or new_lives < self.lives
                self.lives = new_lives

            if done:
                break
            # Only the last two frames of the skip window are kept for pooling.
            if t == self.frame_skip - 2:
                self._capture_screen(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                self._capture_screen(self.obs_buffer[0])
        return self._get_obs(), total_reward, done, info

    def reset(self, **kwargs):
        """Reset the env, then take a random number of no-op steps.

        All keyword arguments are forwarded to the wrapped env's ``reset``.
        If a no-op step ends the episode, the env is reset again with the
        same keyword arguments.

        Returns:
            The preprocessed observation, or ``(observation, info)`` when
            ``return_info=True`` is passed.
        """
        # NoopReset
        return_info = kwargs.get("return_info", False)
        reset_info = self._reset_env(**kwargs)

        noops = (
            self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
            if self.noop_max > 0
            else 0
        )
        for _ in range(noops):
            _, _, done, step_info = self.env.step(0)
            reset_info.update(step_info)
            if done:
                reset_info = self._reset_env(**kwargs)

        self.lives = self.ale.lives()
        self._capture_screen(self.obs_buffer[0])
        # Zero the second buffer so max-pooling returns the fresh frame as is.
        self.obs_buffer[1].fill(0)

        if return_info:
            return self._get_obs(), reset_info
        return self._get_obs()

    def _get_obs(self):
        """Build the observation from the frame buffers.

        Max-pools the two buffered frames (when ``frame_skip > 1``), resizes
        to ``screen_size`` x ``screen_size``, optionally scales to float32 in
        [0, 1], and optionally appends a channel axis for grayscale frames.
        """
        if self.frame_skip > 1:  # more efficient in-place pooling
            # The result overwrites obs_buffer[0].
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
        obs = cv2.resize(
            self.obs_buffer[0],
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )

        if self.scale_obs:
            obs = np.asarray(obs, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(obs, dtype=np.uint8)

        if self.grayscale_obs and self.grayscale_newaxis:
            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis
        return obs