Add shape property and equality operation to LazyFrames (#1862)

* Add shape property and equality operation to LazyFrames for simplified usage in tests

* Fix shape property with lz4_compress
This commit is contained in:
johannespitz
2020-04-11 00:10:10 +02:00
committed by GitHub
parent 9a7f13911c
commit a8a3d36353

View File

@@ -2,7 +2,7 @@ from collections import deque
import numpy as np
from gym.spaces import Box
from gym import ObservationWrapper
from gym import Wrapper
class LazyFrames(object):
@@ -19,7 +19,7 @@ class LazyFrames(object):
def __init__(self, frames, lz4_compress=False):
if lz4_compress:
from lz4.block import compress
self.shape = frames[0].shape
self.frame_shape = frames[0].shape
self.dtype = frames[0].dtype
frames = [compress(frame) for frame in frames]
self._frames = frames
@@ -28,7 +28,7 @@ class LazyFrames(object):
def __array__(self, dtype=None):
if self.lz4_compress:
from lz4.block import decompress
frames = [np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.shape) for frame in self._frames]
frames = [np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape) for frame in self._frames]
else:
frames = self._frames
out = np.stack(frames, axis=0)
@@ -42,8 +42,15 @@ class LazyFrames(object):
def __getitem__(self, i):
return self.__array__()[i]
def __eq__(self, other):
return self.__array__() == other
class FrameStack(ObservationWrapper):
@property
def shape(self):
return self.__array__().shape
class FrameStack(Wrapper):
r"""Observation wrapper that stacks the observations in a rolling manner.
For example, if the number of stacks is 4, then the returned observation contains