2019-08-23 15:02:33 -07:00
|
|
|
import collections
|
2020-12-17 14:55:52 -05:00
|
|
|
from collections.abc import MutableMapping
|
2019-08-23 15:02:33 -07:00
|
|
|
import copy
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from gym import spaces
|
|
|
|
from gym import ObservationWrapper
|
|
|
|
|
2020-12-17 14:55:52 -05:00
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
# Key under which the wrapped environment's original (non-dict, `Box`)
# observation is stored when `pixels_only=False`. For such environments it is
# also reserved, so it may not be used as one of the `pixel_keys`.
STATE_KEY = "state"
|
2019-08-23 15:02:33 -07:00
|
|
|
|
|
|
|
|
|
|
|
class PixelObservationWrapper(ObservationWrapper):
    """Augment observations by pixel values.

    For every key in ``pixel_keys`` the wrapped environment is rendered with
    ``env.render(**render_kwargs[key])`` and the resulting array is stored in a
    dictionary observation under that key, optionally alongside the original
    observation.
    """

    def __init__(
        self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)
    ):
        """Initializes a new pixel Wrapper.

        Args:
            env: The environment to wrap.
            pixels_only: If `True` (default), the original observation returned
                by the wrapped environment will be discarded, and a dictionary
                observation will only include pixels. If `False`, the
                observation dictionary will contain both the original
                observations and the pixel observations.
            render_kwargs: Optional `dict` mapping each pixel key to a `dict`
                of keyword arguments passed to `self.env.render` for that key.
                A `"mode"` entry, if present, must be `"rgb_array"`. The
                caller's dict (and its nested dicts) is copied, never modified.
            pixel_keys: Optional iterable of strings naming the pixel
                observations' keys in the observation dictionary. A single
                string is treated as one key. Defaults to `("pixels",)`.

        Raises:
            ValueError: If `env`'s observation spec is not compatible with the
                wrapper. Supported formats are a single array, or a dict of
                arrays.
            ValueError: If `env`'s observation already contains any of the
                specified `pixel_keys` (only checked when `pixels_only` is
                `False`).
            TypeError: If a rendered frame's dtype is neither an integer nor a
                floating-point type.
            AssertionError: If a `render_kwargs` entry requests a render mode
                other than `"rgb_array"` (when assertions are enabled).
        """
        super().__init__(env)

        # A bare string would otherwise be iterated character by character,
        # and a one-shot iterable would be exhausted after the first loop.
        if isinstance(pixel_keys, str):
            pixel_keys = (pixel_keys,)
        pixel_keys = tuple(pixel_keys)

        # Work on copies so the caller's (possibly shared) dicts are untouched.
        render_kwargs = dict(render_kwargs) if render_kwargs is not None else {}
        for key in pixel_keys:
            key_kwargs = dict(render_kwargs.get(key, {}))
            render_mode = key_kwargs.pop("mode", "rgb_array")
            assert render_mode == "rgb_array", render_mode
            key_kwargs["mode"] = "rgb_array"
            render_kwargs[key] = key_kwargs

        wrapped_observation_space = env.observation_space

        if isinstance(wrapped_observation_space, spaces.Box):
            self._observation_is_dict = False
            # The original observation will live under STATE_KEY.
            invalid_keys = {STATE_KEY}
        elif isinstance(wrapped_observation_space, (spaces.Dict, MutableMapping)):
            self._observation_is_dict = True
            invalid_keys = set(wrapped_observation_space.spaces.keys())
        else:
            raise ValueError("Unsupported observation space structure.")

        if not pixels_only:
            # Make sure that no keys in `pixel_keys` overlap with the wrapped
            # observation's keys (or the reserved STATE_KEY).
            overlapping_keys = set(pixel_keys) & invalid_keys
            if overlapping_keys:
                raise ValueError(
                    f"Duplicate or reserved pixel keys {overlapping_keys!r}."
                )

        if pixels_only:
            self.observation_space = spaces.Dict()
        elif self._observation_is_dict:
            self.observation_space = copy.deepcopy(wrapped_observation_space)
        else:
            self.observation_space = spaces.Dict()
            self.observation_space.spaces[STATE_KEY] = wrapped_observation_space

        # Extend observation space with pixels, inferring shape and dtype from
        # one render per key.
        pixels_spaces = {}
        for pixel_key in pixel_keys:
            pixels = self.env.render(**render_kwargs[pixel_key])

            if np.issubdtype(pixels.dtype, np.integer):
                low, high = (0, 255)
            elif np.issubdtype(pixels.dtype, np.floating):
                # `np.floating` matches every float width; the removed
                # `np.float` alias only ever meant the builtin `float`.
                low, high = (-float("inf"), float("inf"))
            else:
                raise TypeError(pixels.dtype)

            pixels_spaces[pixel_key] = spaces.Box(
                shape=pixels.shape, low=low, high=high, dtype=pixels.dtype
            )

        self.observation_space.spaces.update(pixels_spaces)

        self._env = env
        self._pixels_only = pixels_only
        self._render_kwargs = render_kwargs
        self._pixel_keys = pixel_keys

    def observation(self, observation):
        """Returns `observation` augmented with freshly rendered pixel arrays."""
        return self._add_pixel_observation(observation)

    def _add_pixel_observation(self, wrapped_observation):
        """Builds the dictionary observation for one step.

        Args:
            wrapped_observation: The observation produced by the wrapped env.

        Returns:
            A mapping holding one rendered array per pixel key. Unless
            `pixels_only` was set, it also holds the wrapped observation's
            entries (dict spaces) or the observation under `STATE_KEY`
            (`Box` spaces).
        """
        if self._pixels_only:
            observation = collections.OrderedDict()
        elif self._observation_is_dict:
            # Shallow copy of the same mapping type, so adding pixel keys does
            # not mutate the wrapped env's observation object.
            observation = type(wrapped_observation)(wrapped_observation)
        else:
            observation = collections.OrderedDict()
            observation[STATE_KEY] = wrapped_observation

        pixel_observations = {
            pixel_key: self.env.render(**self._render_kwargs[pixel_key])
            for pixel_key in self._pixel_keys
        }

        observation.update(pixel_observations)

        return observation
|