diff --git a/baselines/common/atari_wrappers.py b/baselines/common/atari_wrappers.py index 210b0fc..6f551a7 100644 --- a/baselines/common/atari_wrappers.py +++ b/baselines/common/atari_wrappers.py @@ -129,18 +129,26 @@ class ClipRewardEnv(gym.RewardWrapper): return np.sign(reward) class WarpFrame(gym.ObservationWrapper): - def __init__(self, env, width=84, height=84): + def __init__(self, env, width=84, height=84, grayscale=True): """Warp frames to 84x84 as done in the Nature paper and later work.""" gym.ObservationWrapper.__init__(self, env) self.width = width self.height = height - self.observation_space = spaces.Box(low=0, high=255, - shape=(self.height, self.width, 1), dtype=np.uint8) + self.grayscale = grayscale + if self.grayscale: + self.observation_space = spaces.Box(low=0, high=255, + shape=(self.height, self.width, 1), dtype=np.uint8) + else: + self.observation_space = spaces.Box(low=0, high=255, + shape=(self.height, self.width, 3), dtype=np.uint8) def observation(self, frame): - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + if self.grayscale: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) - return frame[:, :, None] + if self.grayscale: + frame = np.expand_dims(frame, -1) + return frame class FrameStack(gym.Wrapper): def __init__(self, env, k): @@ -156,7 +164,7 @@ class FrameStack(gym.Wrapper): self.k = k self.frames = deque([], maxlen=k) shp = env.observation_space.shape - self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype) + self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) def reset(self): ob = self.env.reset() @@ -197,7 +205,7 @@ class LazyFrames(object): def _force(self): if self._out is None: - self._out = np.concatenate(self._frames, axis=2) + self._out = np.concatenate(self._frames, axis=-1) self._frames = None return self._out