mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
Add shape property and equality operation to LazyFrames (#1862)
* Add shape property and equality operation to LazyFrames for simplified usage in tests * Fix shape property with lz4_compress
This commit is contained in:
@@ -2,7 +2,7 @@ from collections import deque
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Box
|
||||
from gym import ObservationWrapper
|
||||
from gym import Wrapper
|
||||
|
||||
|
||||
class LazyFrames(object):
|
||||
@@ -19,7 +19,7 @@ class LazyFrames(object):
|
||||
def __init__(self, frames, lz4_compress=False):
|
||||
if lz4_compress:
|
||||
from lz4.block import compress
|
||||
self.shape = frames[0].shape
|
||||
self.frame_shape = frames[0].shape
|
||||
self.dtype = frames[0].dtype
|
||||
frames = [compress(frame) for frame in frames]
|
||||
self._frames = frames
|
||||
@@ -28,7 +28,7 @@ class LazyFrames(object):
|
||||
def __array__(self, dtype=None):
|
||||
if self.lz4_compress:
|
||||
from lz4.block import decompress
|
||||
frames = [np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.shape) for frame in self._frames]
|
||||
frames = [np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape) for frame in self._frames]
|
||||
else:
|
||||
frames = self._frames
|
||||
out = np.stack(frames, axis=0)
|
||||
@@ -42,8 +42,15 @@ class LazyFrames(object):
|
||||
def __getitem__(self, i):
|
||||
return self.__array__()[i]
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__array__() == other
|
||||
|
||||
class FrameStack(ObservationWrapper):
|
||||
@property
|
||||
def shape(self):
|
||||
return self.__array__().shape
|
||||
|
||||
|
||||
class FrameStack(Wrapper):
|
||||
r"""Observation wrapper that stacks the observations in a rolling manner.
|
||||
|
||||
For example, if the number of stacks is 4, then the returned observation contains
|
||||
|
Reference in New Issue
Block a user