Files
Gymnasium/gymnasium/wrappers/frame_stack.py

191 lines
6.2 KiB
Python
Raw Normal View History

2022-05-13 13:58:19 +01:00
"""Wrapper that stacks frames."""
from collections import deque
2022-05-13 13:58:19 +01:00
from typing import Union
Seeding update (#2422) * Ditch most of the seeding.py and replace np_random with the numpy default_rng. Let's see if tests pass * Updated a bunch of RNG calls from the RandomState API to Generator API * black; didn't expect that, did ya? * Undo a typo * blaaack * More typo fixes * Fixed setting/getting state in multidiscrete spaces * Fix typo, fix a test to work with the new sampling * Correctly (?) pass the randomly generated seed if np_random is called with None as seed * Convert the Discrete sample to a python int (as opposed to np.int64) * Remove some redundant imports * First version of the compatibility layer for old-style RNG. Mainly to trigger tests. * Removed redundant f-strings * Style fixes, removing unused imports * Try to make tests pass by removing atari from the dockerfile * Try to make tests pass by removing atari from the setup * Try to make tests pass by removing atari from the setup * Try to make tests pass by removing atari from the setup * First attempt at deprecating `env.seed` and supporting `env.reset(seed=seed)` instead. Tests should hopefully pass but throw up a million warnings. * black; didn't expect that, didya? * Rename the reset parameter in VecEnvs back to `seed` * Updated tests to use the new seeding method * Removed a bunch of old `seed` calls. Fixed a bug in AsyncVectorEnv * Stop Discrete envs from doing part of the setup (and using the randomness) in init (as opposed to reset) * Add explicit seed to wrappers reset * Remove an accidental return * Re-add some legacy functions with a warning. * Use deprecation instead of regular warnings for the newly deprecated methods/functions
2021-12-08 22:14:15 +01:00
import numpy as np
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import Box
class LazyFrames:
    """Ensures common frames are only stored once to optimize memory use.

    To further reduce the memory use, lz4 compression of the observations can
    optionally be turned on.

    Note:
        This object should only be converted to numpy array just before forward pass.
    """

    __slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")

    def __init__(self, frames: list, lz4_compress: bool = False):
        """Lazyframe for a set of frames and if to apply lz4.

        Args:
            frames (list): The frames to convert to lazy frames
            lz4_compress (bool): Use lz4 to compress the frames internally

        Raises:
            DependencyNotInstalled: lz4 is not installed
        """
        # Shape and dtype are taken from the first frame; they are needed later
        # to rebuild arrays from the compressed byte buffers.
        self.frame_shape = tuple(frames[0].shape)
        self.shape = (len(frames),) + self.frame_shape
        self.dtype = frames[0].dtype
        if lz4_compress:
            try:
                from lz4.block import compress
            except ImportError as err:
                # Chain the original ImportError so the root cause stays visible.
                raise DependencyNotInstalled(
                    "lz4 is not installed, run `pip install gymnasium[other]`"
                ) from err

            frames = [compress(frame) for frame in frames]
        self._frames = frames
        self.lz4_compress = lz4_compress

    def __array__(self, dtype=None, copy=None):
        """Gets a numpy array of stacked frames with specific dtype.

        Args:
            dtype: The dtype of the stacked frames
            copy: NumPy array-protocol flag. Stacking always allocates a new
                array, so ``copy=False`` cannot be honoured.

        Returns:
            The array of stacked frames with dtype

        Raises:
            ValueError: If ``copy=False`` is requested
        """
        if copy is False:
            raise ValueError(
                "LazyFrames cannot be converted to a numpy array without copying."
            )
        arr = self[:]
        if dtype is not None:
            return arr.astype(dtype)
        return arr

    def __len__(self):
        """Returns the number of frame stacks.

        Returns:
            The number of frame stacks
        """
        return self.shape[0]

    def __getitem__(self, int_or_slice: Union[int, slice]):
        """Gets the stacked frames for a particular index or slice.

        Args:
            int_or_slice: Index or slice to get items for

        Returns:
            A single frame for an integer index, np.stacked frames for a slice
        """
        # NumPy integers (e.g. ``np.int64``) are not ``int`` subclasses but are
        # valid scalar indices; treat them like ints so a single frame is
        # returned instead of iterating over that frame's contents.
        if isinstance(int_or_slice, (int, np.integer)):
            return self._check_decompress(self._frames[int_or_slice])  # single frame
        return np.stack(
            [self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
        )

    def __eq__(self, other):
        """Checks that the current frames are equal to the other object."""
        return self.__array__() == other

    def _check_decompress(self, frame):
        """Returns ``frame`` as an array, decompressing it first if lz4 is enabled."""
        if self.lz4_compress:
            from lz4.block import decompress

            return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
                self.frame_shape
            )
        return frame
class FrameStack(gym.ObservationWrapper):
    """Observation wrapper that stacks the observations in a rolling manner.

    For example, if the number of stacks is 4, then the returned observation contains
    the most recent 4 observations. For environment 'Pendulum-v1', the original observation
    is an array with shape [3], so if we stack 4 observations, the processed observation
    has shape [4, 3].

    Note:
        - To be memory efficient, the stacked observations are wrapped by :class:`LazyFrames`.
        - The observation space must be :class:`Box` type. If one uses :class:`Dict`
          as observation space, it should apply :class:`FlattenObservation` wrapper first.
        - After :meth:`reset` is called, the frame buffer will be filled with the initial observation.
          I.e. the observation returned by :meth:`reset` will consist of ``num_stack``-many identical frames.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make('CarRacing-v1')
        >>> env = FrameStack(env, 4)
        >>> env.observation_space
        Box(4, 96, 96, 3)
        >>> obs, info = env.reset()
        >>> obs.shape
        (4, 96, 96, 3)
    """

    def __init__(
        self,
        env: gym.Env,
        num_stack: int,
        lz4_compress: bool = False,
    ):
        """Observation wrapper that stacks the observations in a rolling manner.

        Args:
            env (Env): The environment to apply the wrapper
            num_stack (int): The number of frames to stack
            lz4_compress (bool): Use lz4 to compress the frames internally
        """
        super().__init__(env)
        self.num_stack = num_stack
        self.lz4_compress = lz4_compress

        # Bounded deque: appending a new frame drops the oldest one automatically.
        self.frames = deque(maxlen=num_stack)

        # Repeat the wrapped space's bounds along a new leading "stack" axis.
        low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
        high = np.repeat(
            self.observation_space.high[np.newaxis, ...], num_stack, axis=0
        )
        self.observation_space = Box(
            low=low, high=high, dtype=self.observation_space.dtype
        )

    def observation(self, observation):
        """Converts the wrappers current frames to lazy frames.

        Args:
            observation: Ignored

        Returns:
            :class:`LazyFrames` object for the wrapper's frame buffer, :attr:`self.frames`
        """
        # Internal invariant: the buffer is full once reset() has been called.
        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
        return LazyFrames(list(self.frames), self.lz4_compress)

    def step(self, action):
        """Steps through the environment, appending the observation to the frame buffer.

        Args:
            action: The action to step through the environment with

        Returns:
            Stacked observations, reward, terminated, truncated, and information from the environment
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(observation)
        return self.observation(None), reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the environment with kwargs.

        Args:
            **kwargs: The kwargs for the environment reset

        Returns:
            The stacked observations and the info returned by the environment reset
        """
        obs, info = self.env.reset(**kwargs)

        # Fill the whole buffer with the initial observation so the first
        # stacked observation already has ``num_stack`` frames.
        for _ in range(self.num_stack):
            self.frames.append(obs)

        return self.observation(None), info