2019-05-10 23:59:32 +02:00
import numpy as np
2021-07-28 22:21:47 -04:00
import warnings
2019-05-10 23:59:32 +02:00
import gym
from gym . spaces import Box
from gym . wrappers import TimeLimit
2021-07-29 02:26:34 +02:00
2020-01-24 14:05:12 -08:00
# OpenCV is an optional dependency: bind ``cv2`` to None when it cannot be
# imported so that importing this module still succeeds. Code that needs
# OpenCV must check ``cv2 is not None`` before using it.
try:
    import cv2
except ImportError:
    cv2 = None
2019-05-10 23:59:32 +02:00
class AtariPreprocessing(gym.Wrapper):
    r"""Atari 2600 preprocessings.

    This class follows the guidelines in
    Machado et al. (2018), "Revisiting the Arcade Learning Environment:
    Evaluation Protocols and Open Problems for General Agents".

    Specifically:

    * NoopReset: obtain initial state by taking random number of no-ops on reset.
    * Frame skipping: 4 by default
    * Max-pooling: most recent two observations
    * Termination signal when a life is lost: turned off by default. Not
      recommended by Machado et al. (2018).
    * Resize to a square image: 84x84 by default
    * Grayscale observation: enabled by default, can be turned off
    * Scale observation: optional

    Args:
        env (Env): environment
        noop_max (int): max number of no-ops
        frame_skip (int): the frequency at which the agent experiences the game.
        screen_size (int): resize Atari frame
        terminal_on_life_loss (bool): if True, then step() returns done=True
            whenever a life is lost.
        grayscale_obs (bool): if True, then gray scale observation is returned,
            otherwise, RGB observation is returned.
        grayscale_newaxis (bool): if True and grayscale_obs=True, then a channel
            axis is added to grayscale observations to make them 3-dimensional.
        scale_obs (bool): if True, then observation normalized in range [0, 1]
            is returned. It also limits memory optimization benefits of the
            FrameStack wrapper.

    Attributes set by the wrapper:
        lives: life count read from the emulator on reset, and on every
            emulator step when ``terminal_on_life_loss`` is enabled.
        game_over: the ``done`` flag reported by the wrapped env on the most
            recent emulator step (not affected by life loss); cleared on reset.
    """

    def __init__(
        self,
        env,
        noop_max=30,
        frame_skip=4,
        screen_size=84,
        terminal_on_life_loss=False,
        grayscale_obs=True,
        grayscale_newaxis=False,
        scale_obs=False,
    ):
        """Validate the configuration and allocate the frame buffers.

        Raises:
            AssertionError: if OpenCV is not installed, if ``frame_skip`` or
                ``screen_size`` is not positive, if ``noop_max`` is negative,
                if ``frame_skip > 1`` but the env id does not contain
                "NoFrameskip", or if action 0 of the env is not "NOOP".
        """
        super().__init__(env)
        assert (
            cv2 is not None
        ), "opencv-python package not installed! Try running pip install gym[atari] to get dependencies for atari"
        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0
        if frame_skip > 1:
            # The wrapper repeats the action itself, so the wrapped env must
            # not skip frames on its own.
            assert "NoFrameskip" in env.spec.id, (
                "disable frame-skipping in the original env. for more than one"
                " frame-skip as it will be done by the wrapper"
            )
        self.noop_max = noop_max
        # reset() issues action 0 as the no-op, so it has to really be NOOP.
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

        # stacklevel=2 attributes the warning to the code constructing the
        # wrapper instead of to this module.
        warnings.warn(
            "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit",
            stacklevel=2,
        )

        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.terminal_on_life_loss = terminal_on_life_loss
        self.grayscale_obs = grayscale_obs
        self.grayscale_newaxis = grayscale_newaxis
        self.scale_obs = scale_obs

        # Buffer of the most recent two observations for max pooling.
        # Grayscale frames drop the channel axis of the raw observation shape.
        raw_shape = env.observation_space.shape
        buffer_shape = raw_shape[:2] if grayscale_obs else raw_shape
        self.obs_buffer = [
            np.empty(buffer_shape, dtype=np.uint8),
            np.empty(buffer_shape, dtype=np.uint8),
        ]

        self.ale = env.unwrapped.ale
        self.lives = 0
        self.game_over = False

        if scale_obs:
            _low, _high, _obs_dtype = 0, 1, np.float32
        else:
            _low, _high, _obs_dtype = 0, 255, np.uint8
        _shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
        if grayscale_obs and not grayscale_newaxis:
            _shape = _shape[:-1]  # Remove channel axis
        self.observation_space = Box(
            low=_low, high=_high, shape=_shape, dtype=_obs_dtype
        )

    def _capture_screen(self, buffer):
        """Write the emulator's current screen into ``buffer`` in place."""
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(buffer)
        else:
            self.ale.getScreenRGB2(buffer)

    def step(self, action):
        """Repeat ``action`` for up to ``frame_skip`` emulator steps.

        Returns:
            A tuple ``(obs, total_reward, done, info)`` where ``total_reward``
            is the sum of the rewards of the executed emulator steps and
            ``info`` comes from the last executed emulator step.
        """
        total_reward = 0.0

        for t in range(self.frame_skip):
            _, reward, done, info = self.env.step(action)
            total_reward += reward
            self.game_over = done

            if self.terminal_on_life_loss:
                new_lives = self.ale.lives()
                done = done or new_lives < self.lives
                self.lives = new_lives

            # NOTE: breaking here skips the screen captures below, so the
            # returned observation is built from whatever the buffers
            # already held before this call.
            if done:
                break

            # Only the last two frames of the skip window are kept for pooling.
            if t == self.frame_skip - 2:
                self._capture_screen(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                self._capture_screen(self.obs_buffer[0])

        return self._get_obs(), total_reward, done, info

    def reset(self, **kwargs):
        """Reset the env, apply random no-ops and return the first observation.

        ``kwargs`` are forwarded to the wrapped env's ``reset``.
        """
        # NoopReset
        self.env.reset(**kwargs)
        if self.noop_max > 0:
            noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
        else:
            noops = 0
        for _ in range(noops):
            _, _, done, _ = self.env.step(0)
            if done:
                self.env.reset(**kwargs)

        self.lives = self.ale.lives()
        # A fresh episode is not over; do not leak the flag from the last one.
        self.game_over = False

        self._capture_screen(self.obs_buffer[0])
        # Zero the second buffer so max pooling returns the captured frame.
        self.obs_buffer[1].fill(0)
        return self._get_obs()

    def _get_obs(self):
        """Build the processed observation from the frame buffers.

        Max-pools the two buffers (when ``frame_skip > 1``), resizes to
        ``screen_size`` x ``screen_size``, then converts to float32 in [0, 1]
        or uint8 depending on ``scale_obs`` and optionally appends a channel
        axis to grayscale frames.
        """
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])

        obs = cv2.resize(
            self.obs_buffer[0],
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )

        if self.scale_obs:
            obs = np.asarray(obs, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(obs, dtype=np.uint8)

        if self.grayscale_obs and self.grayscale_newaxis:
            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis
        return obs