2019-08-23 23:04:11 +02:00
from collections import deque
import numpy as np
2021-07-28 22:21:47 -04:00
import warnings
2019-08-23 23:04:11 +02:00
from gym . spaces import Box
2020-04-11 00:10:10 +02:00
from gym import Wrapper
2019-08-23 23:04:11 +02:00
class LazyFrames(object):
    r"""Ensures common frames are only stored once to optimize memory use.

    The individual frames are kept in a list and are only stacked into a
    single numpy array on demand (indexing, ``__array__``, ``==``).

    To further reduce the memory use, lz4 compression of the stored frames
    can optionally be turned on.

    .. note::

        This object should only be converted to numpy array just before forward pass.

    Args:
        frames (list): the frames to hold; every frame is expected to expose
            ``shape`` and ``dtype`` like a numpy array, and the first frame
            determines the shape/dtype reported for all of them
        lz4_compress (bool): use lz4 to compress the frames internally

    """
    __slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")

    def __init__(self, frames, lz4_compress=False):
        warnings.warn(
            "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
        )
        # Shape and dtype are taken from the first frame and are needed later
        # to rebuild arrays from the raw decompressed bytes.
        self.frame_shape = tuple(frames[0].shape)
        self.shape = (len(frames),) + self.frame_shape
        self.dtype = frames[0].dtype
        if lz4_compress:
            # Imported lazily so lz4 is only required when compression is used.
            from lz4.block import compress

            frames = [compress(frame) for frame in frames]
        self._frames = frames
        self.lz4_compress = lz4_compress

    def __array__(self, dtype=None):
        """Return all frames stacked along a new first axis.

        Args:
            dtype: optional dtype the stacked array is cast to

        Returns:
            numpy array of shape ``self.shape``
        """
        arr = self[:]
        if dtype is not None:
            return arr.astype(dtype)
        return arr

    def __len__(self):
        """Return the number of stored frames."""
        return self.shape[0]

    def __getitem__(self, int_or_slice):
        """Return one frame (integer index) or a stacked array of frames (slice).

        Both builtin ``int`` and numpy integer scalars select a single frame;
        anything else is forwarded to the underlying list as a slice and the
        selected frames are stacked along axis 0.
        """
        # np.integer must be accepted explicitly: np.int64 and friends are not
        # subclasses of int, and would otherwise be treated as a slice, which
        # iterates over a single (possibly compressed) frame.
        if isinstance(int_or_slice, (int, np.integer)):
            return self._check_decompress(self._frames[int_or_slice])  # single frame
        return np.stack(
            [self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
        )

    def __eq__(self, other):
        """Compare the stacked frames with ``other`` using numpy's ``==``."""
        return self.__array__() == other

    def _check_decompress(self, frame):
        """Return ``frame`` as an array, decompressing it first if lz4 is enabled."""
        if self.lz4_compress:
            from lz4.block import decompress

            # The compressed payload is raw bytes; dtype and shape recorded in
            # __init__ are used to reinterpret them as the original frame.
            return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
                self.frame_shape
            )
        return frame
2020-04-11 00:10:10 +02:00
class FrameStack(Wrapper):
    r"""Observation wrapper that stacks the observations in a rolling manner.

    For example, if the number of stacks is 4, then the returned observation contains
    the most recent 4 observations. For environment 'Pendulum-v0', the original observation
    is an array with shape [3], so if we stack 4 observations, the processed observation
    has shape [4, 3].

    .. note::

        To be memory efficient, the stacked observations are wrapped by :class:`LazyFrames`.

    .. note::

        The observation space must be `Box` type. If one uses `Dict`
        as observation space, it should apply `FlattenDictWrapper` at first.

    Example::

        >>> import gym
        >>> env = gym.make('PongNoFrameskip-v0')
        >>> env = FrameStack(env, 4)
        >>> env.observation_space
        Box(4, 210, 160, 3)

    Args:
        env (Env): environment object
        num_stack (int): number of stacks
        lz4_compress (bool): use lz4 to compress the frames internally

    """

    def __init__(self, env, num_stack, lz4_compress=False):
        super(FrameStack, self).__init__(env)
        self.num_stack = num_stack
        self.lz4_compress = lz4_compress

        # Bounded deque: appending beyond num_stack drops the oldest frame.
        self.frames = deque(maxlen=num_stack)

        # Repeat the wrapped space's bounds along a new leading axis so the
        # stacked observation space has shape (num_stack, *original_shape).
        low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
        high = np.repeat(
            self.observation_space.high[np.newaxis, ...], num_stack, axis=0
        )
        self.observation_space = Box(
            low=low, high=high, dtype=self.observation_space.dtype
        )

    def _get_observation(self):
        """Return the currently buffered frames wrapped in a ``LazyFrames``."""
        # Internal invariant: the buffer is filled by reset() before any step().
        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
        return LazyFrames(list(self.frames), self.lz4_compress)

    def step(self, action):
        """Step the wrapped environment and return the updated frame stack.

        Returns:
            tuple ``(stacked_observation, reward, done, info)`` where the last
            three items are passed through unchanged from the wrapped env
        """
        observation, reward, done, info = self.env.step(action)
        self.frames.append(observation)
        return self._get_observation(), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped environment and refill the frame buffer.

        The initial observation is appended ``num_stack`` times so the buffer
        is full from the very first returned observation. Keyword arguments
        are forwarded to the wrapped environment's ``reset``.
        """
        observation = self.env.reset(**kwargs)
        # Plain loop instead of a comprehension: only the side effect matters.
        for _ in range(self.num_stack):
            self.frames.append(observation)
        return self._get_observation()