diff --git a/baselines/common/atari_wrappers.py b/baselines/common/atari_wrappers.py index 2c9b8c6..b3e2131 100644 --- a/baselines/common/atari_wrappers.py +++ b/baselines/common/atari_wrappers.py @@ -130,27 +130,60 @@ class ClipRewardEnv(gym.RewardWrapper): """Bin reward to {+1, 0, -1} by its sign.""" return np.sign(reward) -class WarpFrame(gym.ObservationWrapper): - def __init__(self, env, width=84, height=84, grayscale=True): - """Warp frames to 84x84 as done in the Nature paper and later work.""" - gym.ObservationWrapper.__init__(self, env) - self.width = width - self.height = height - self.grayscale = grayscale - if self.grayscale: - self.observation_space = spaces.Box(low=0, high=255, - shape=(self.height, self.width, 1), dtype=np.uint8) - else: - self.observation_space = spaces.Box(low=0, high=255, - shape=(self.height, self.width, 3), dtype=np.uint8) - def observation(self, frame): - if self.grayscale: +class WarpFrame(gym.ObservationWrapper): + def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None): + """ + Warp frames to 84x84 as done in the Nature paper and later work. + + If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which + observation should be warped. + """ + super().__init__(env) + self._width = width + self._height = height + self._grayscale = grayscale + self._key = dict_space_key + if self._grayscale: + num_colors = 1 + else: + num_colors = 3 + + new_space = gym.spaces.Box( + low=0, + high=255, + shape=(self._height, self._width, num_colors), + dtype=np.uint8, + ) + if self._key is None: + original_space = self.observation_space + self.observation_space = new_space + else: + original_space = self.observation_space.spaces[self._key] + self.observation_space.spaces[self._key] = new_space + assert original_space.dtype == np.uint8 and len(original_space.shape) == 3 + + def observation(self, obs): + if self._key is None: + frame = obs + else: + frame = obs[self._key] + + if self._grayscale: frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) - if self.grayscale: + frame = cv2.resize( + frame, (self._width, self._height), interpolation=cv2.INTER_AREA + ) + if self._grayscale: frame = np.expand_dims(frame, -1) - return frame + + if self._key is None: + obs = frame + else: + obs = obs.copy() + obs[self._key] = frame + return obs + class FrameStack(gym.Wrapper): def __init__(self, env, k):