diff --git a/baselines/common/atari_wrappers.py b/baselines/common/atari_wrappers.py index fb768e8..8be0a2f 100644 --- a/baselines/common/atari_wrappers.py +++ b/baselines/common/atari_wrappers.py @@ -254,6 +254,13 @@ class LazyFrames(object): return len(self._force()) def __getitem__(self, i): + return self._force()[i] + + def count(self): + frames = self._force() + return frames.shape[frames.ndim - 1] + + def frame(self, i): return self._force()[..., i] def make_atari(env_id, max_episode_steps=None):