Files
Gymnasium/gym/wrappers/frame_stack.py
Mark Towers bf093c6890 Update the flake8 pre-commit ignores (#2778)
* Remove additional ignores from flake8

* Remove all unused imports

* Remove all unused imports

* Update flake8 and pyupgrade

* F841, removed unused variables

* E731, removed lambda assignment to variables

* Remove E731, F403, F405, F524

* Remove E722, bare exceptions

* Remove E712, compare variable == True or == False to is True or is False

* Remove E402, module level import not at top of file

* Added --per-file-ignores

* Add --per-file-ignores removing E741, E302 and E704

* Add E741, do not use variables named ‘l’, ‘O’, or ‘I’ to ignore issues in classic control

* Fixed issues for pytest==6.2

* Remove unnecessary # noqa

* Edit comment with the removal of E302

* Added warnings and declared module, attr for pyright type hinting

* Remove unused import

* Removed flake8 E302

* Updated flake8 from 3.9.2 to 4.0.1

* Remove unused variable
2022-04-26 11:18:37 -04:00

133 lines
4.1 KiB
Python

from collections import deque
import numpy as np
from gym import ObservationWrapper
from gym.spaces import Box
class LazyFrames:
    r"""Ensures common frames are only stored once to optimize memory use.

    To further reduce the memory use, it is optional to turn on lz4 to
    compress the observations.

    .. note::

        This object should only be converted to numpy array just before forward pass.

    Args:
        frames (list): the frames to hold; the shape and dtype of the first
            frame are recorded and reused to decode every compressed frame
        lz4_compress (bool): use lz4 to compress the frames internally
    """

    __slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")

    def __init__(self, frames, lz4_compress=False):
        self.frame_shape = tuple(frames[0].shape)
        self.shape = (len(frames),) + self.frame_shape
        self.dtype = frames[0].dtype
        if lz4_compress:
            # Imported lazily so lz4 is only needed when compression is requested.
            from lz4.block import compress

            frames = [compress(frame) for frame in frames]
        self._frames = frames
        self.lz4_compress = lz4_compress

    def __array__(self, dtype=None, copy=None):
        """Materialize all frames as one stacked numpy array.

        Args:
            dtype: optional dtype to cast the stacked array to
            copy: accepted for compatibility with the NumPy 2 ``__array__``
                protocol; it has no effect because the stacked array is
                always newly built from the stored frames

        Returns:
            An array of shape ``self.shape``.
        """
        arr = self[:]
        if dtype is not None:
            return arr.astype(dtype)
        return arr

    def __len__(self):
        """Return the number of stored frames."""
        return self.shape[0]

    def __getitem__(self, int_or_slice):
        """Return a single (decompressed) frame or a stacked array of frames.

        Args:
            int_or_slice: an integer-like index (including NumPy integers)
                selecting one frame, or a slice selecting several frames

        Returns:
            The selected frame for an index, or the selected frames stacked
            along a new leading axis for a slice.
        """
        if isinstance(int_or_slice, slice):
            return np.stack(
                [self._check_decompress(f) for f in self._frames[int_or_slice]],
                axis=0,
            )
        # Everything that is not a slice is a single index. Testing for
        # ``slice`` rather than ``int`` keeps NumPy integer indices on the
        # single-frame path instead of iterating over the frame's contents.
        return self._check_decompress(self._frames[int_or_slice])  # single frame

    def __eq__(self, other):
        """Compare the stacked frames with ``other`` using array ``==``."""
        return self.__array__() == other

    def _check_decompress(self, frame):
        """Return ``frame`` as an array, decompressing it when lz4 is enabled."""
        if self.lz4_compress:
            from lz4.block import decompress

            return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
                self.frame_shape
            )
        return frame
class FrameStack(ObservationWrapper):
    r"""Observation wrapper that stacks the observations in a rolling manner.

    For example, if the number of stacks is 4, then the returned observation contains
    the most recent 4 observations. For environment 'Pendulum-v1', the original observation
    is an array with shape [3], so if we stack 4 observations, the processed observation
    has shape [4, 3].

    .. note::

        To be memory efficient, the stacked observations are wrapped by :class:`LazyFrames`.

    .. note::

        The observation space must be `Box` type. If one uses `Dict`
        as observation space, it should apply `FlattenDictWrapper` at first.

    Example::

        >>> import gym
        >>> env = gym.make('PongNoFrameskip-v0')
        >>> env = FrameStack(env, 4)
        >>> env.observation_space
        Box(4, 210, 160, 3)

    Args:
        env (Env): environment object
        num_stack (int): number of stacks
        lz4_compress (bool): use lz4 to compress the frames internally

    """

    def __init__(self, env, num_stack, lz4_compress=False):
        super().__init__(env)
        self.num_stack = num_stack
        self.lz4_compress = lz4_compress

        # Bounded deque: appending a new frame drops the oldest one.
        self.frames = deque(maxlen=num_stack)

        # Repeat the wrapped bounds along a new leading "stack" axis.
        low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
        high = np.repeat(
            self.observation_space.high[np.newaxis, ...], num_stack, axis=0
        )
        self.observation_space = Box(
            low=low, high=high, dtype=self.observation_space.dtype
        )

    def observation(self, observation=None):
        """Return the currently stacked frames as a :class:`LazyFrames`.

        Args:
            observation: ignored; accepted so the method can also be called
                with an observation argument. The stacked result is always
                built from ``self.frames``.

        Returns:
            A :class:`LazyFrames` holding the ``num_stack`` buffered frames.
        """
        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
        return LazyFrames(list(self.frames), self.lz4_compress)

    def step(self, action):
        """Step the wrapped environment and return the updated frame stack.

        Args:
            action: the action forwarded to the wrapped environment

        Returns:
            ``(stacked_observation, reward, done, info)``.
        """
        observation, reward, done, info = self.env.step(action)
        self.frames.append(observation)
        return self.observation(), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped environment and fill the stack with the first frame.

        Args:
            **kwargs: forwarded unchanged to the wrapped environment's ``reset``

        Returns:
            The stacked observation, or ``(stacked_observation, info)`` when
            ``return_info`` is passed with a truthy value.
        """
        return_info = kwargs.get("return_info", False)
        result = self.env.reset(**kwargs)
        if return_info:
            obs, info = result
        else:
            obs = result

        # The initial observation fills every slot of the stack.
        self.frames.extend([obs] * self.num_stack)

        if return_info:
            return self.observation(), info
        return self.observation()