2019-05-10 23:59:32 +02:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
import gym
|
|
|
|
from gym.spaces import Box
|
|
|
|
from gym.wrappers import TimeLimit
|
2020-01-24 14:05:12 -08:00
|
|
|
try:
|
|
|
|
import cv2
|
|
|
|
except ImportError:
|
|
|
|
cv2 = None
|
2019-05-10 23:59:32 +02:00
|
|
|
|
|
|
|
|
|
|
|
class AtariPreprocessing(gym.Wrapper):
    r"""Atari 2600 preprocessing.

    This class follows the guidelines in
    Machado et al. (2018), "Revisiting the Arcade Learning Environment:
    Evaluation Protocols and Open Problems for General Agents".

    Specifically:

    * NoopReset: obtain the initial state by taking a random number of no-ops on reset.
    * Frame skipping: 4 by default
    * Max-pooling: most recent two observations
    * Termination signal when a life is lost: turned off by default. Not recommended by Machado et al. (2018).
    * Resize to a square image: 84x84 by default
    * Grayscale observation: optional
    * Scale observation: optional

    Args:
        env (Env): the Atari environment to wrap. Its unwrapped env must expose ``ale`` and
            ``get_action_meanings()``, and action 0 must be ``'NOOP'``.
        noop_max (int): maximum number of no-op actions taken on reset; 0 disables no-op starts.
        frame_skip (int): number of emulator frames per agent step. The action is repeated on
            every frame and the rewards are summed. When greater than 1, the wrapped env must
            not skip frames itself.
        screen_size (int): side length, in pixels, of the square resized observation.
        terminal_on_life_loss (bool): if True, then step() returns done=True whenever a
            life is lost.
        grayscale_obs (bool): if True, then a grayscale observation is returned; otherwise an
            RGB observation is returned.
        grayscale_newaxis (bool): if True and grayscale_obs=True, then a channel axis is added to
            grayscale observations to make them 3-dimensional.
        scale_obs (bool): if True, then a float32 observation normalized to [0, 1] is returned.
            Float32 frames are four times larger than uint8 ones, which reduces the memory
            savings of the FrameStack wrapper.
    """

    def __init__(self, env, noop_max=30, frame_skip=4, screen_size=84, terminal_on_life_loss=False, grayscale_obs=True,
                 grayscale_newaxis=False, scale_obs=False):
        super().__init__(env)
        assert cv2 is not None, \
            "opencv-python package not installed! Try running pip install gym[atari] to get dependencies for atari"
        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0
        if frame_skip > 1:
            # This wrapper performs the frame skipping, so the inner env must not skip as well.
            assert self._inner_frameskip_disabled(env), \
                'disable frame-skipping in the original env. for more than one' \
                ' frame-skip as it will be done by the wrapper'
        self.noop_max = noop_max
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.terminal_on_life_loss = terminal_on_life_loss
        self.grayscale_obs = grayscale_obs
        self.grayscale_newaxis = grayscale_newaxis
        self.scale_obs = scale_obs

        # Buffer of the two most recent raw frames; _get_obs() max-pools them.
        # Grayscale frames drop the channel axis of the env's observation shape.
        buffer_shape = env.observation_space.shape[:2] if grayscale_obs else env.observation_space.shape
        self.obs_buffer = [np.empty(buffer_shape, dtype=np.uint8),
                           np.empty(buffer_shape, dtype=np.uint8)]

        self.ale = env.unwrapped.ale
        self.lives = 0
        self.game_over = False

        _low, _high, _obs_dtype = (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
        _shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
        if grayscale_obs and not grayscale_newaxis:
            _shape = _shape[:-1]  # Remove channel axis
        self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype)

    @staticmethod
    def _inner_frameskip_disabled(env):
        """Return True if ``env`` is known not to skip frames on its own."""
        # `spec` is None for envs constructed without gym.make(); don't crash on it.
        spec = getattr(env, 'spec', None)
        if spec is not None and 'NoFrameskip' in spec.id:
            return True
        # NOTE(review): Atari env implementations store their own skip setting as
        # `frameskip` (gym AtariEnv) or `_frameskip` (ale-py); an int 1 means "no skipping".
        unwrapped = env.unwrapped
        return any(getattr(unwrapped, name, None) == 1 for name in ('frameskip', '_frameskip'))

    def _grab_screen(self, out):
        """Copy the current ALE screen into ``out`` in place, as grayscale or RGB per ``grayscale_obs``."""
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(out)
        else:
            self.ale.getScreenRGB2(out)

    def step(self, action):
        """Repeat ``action`` for up to ``frame_skip`` frames and return the pooled observation.

        Returns:
            tuple: ``(obs, total_reward, done, info)``. ``total_reward`` sums the rewards of
            every frame actually played. ``info`` comes from the last frame played. When
            ``terminal_on_life_loss`` is set, ``done`` is also True after a life is lost.
        """
        R = 0.0

        for t in range(self.frame_skip):
            _, reward, done, info = self.env.step(action)
            R += reward
            # Records the env's own done flag, before any life-loss override below.
            self.game_over = done

            if self.terminal_on_life_loss:
                new_lives = self.ale.lives()
                done = done or new_lives < self.lives
                self.lives = new_lives

            # Capture the last two frames *before* checking `done`, so that a
            # termination on one of them still yields up-to-date pixels.
            if t == self.frame_skip - 2:
                self._grab_screen(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                self._grab_screen(self.obs_buffer[0])

            if done:
                break

        if done and t < self.frame_skip - 1:
            # The episode ended before the final frame, so obs_buffer[0] still holds data
            # from a previous step (_get_obs pools into it in place). Expose the current
            # frame alone, using the same convention as reset().
            self._grab_screen(self.obs_buffer[0])
            self.obs_buffer[1].fill(0)

        return self._get_obs(), R, done, info

    def reset(self, **kwargs):
        """Reset the env, play a random number of no-ops (NoopReset), and return the first observation.

        Keyword arguments are forwarded to every underlying ``reset`` call.
        """
        # NoopReset
        self.env.reset(**kwargs)
        # Draw from [1, noop_max] (randint's upper bound is exclusive); 0 when disabled.
        noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
        for _ in range(noops):
            _, _, done, _ = self.env.step(0)
            if done:
                self.env.reset(**kwargs)

        self.lives = self.ale.lives()
        self.game_over = False
        self._grab_screen(self.obs_buffer[0])
        # Only one frame is available after reset; a zeroed second buffer makes max-pooling a no-op.
        self.obs_buffer[1].fill(0)
        return self._get_obs()

    def _get_obs(self):
        """Max-pool the frame buffers, resize to ``screen_size``, and cast/scale the result.

        Pooling writes into ``obs_buffer[0]`` in place, so that buffer is overwritten.
        """
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
        obs = cv2.resize(self.obs_buffer[0], (self.screen_size, self.screen_size), interpolation=cv2.INTER_AREA)

        if self.scale_obs:
            obs = np.asarray(obs, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(obs, dtype=np.uint8)

        if self.grayscale_obs and self.grayscale_newaxis:
            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis
        return obs
|