mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 15:11:51 +00:00
Add shape property and equality operation to LazyFrames (#1862)
* Add shape property and equality operation to LazyFrames for simplified usage in tests * Fix shape property with lz4_compress
This commit is contained in:
@@ -2,7 +2,7 @@ from collections import deque
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gym.spaces import Box
|
from gym.spaces import Box
|
||||||
from gym import ObservationWrapper
|
from gym import Wrapper
|
||||||
|
|
||||||
|
|
||||||
class LazyFrames(object):
|
class LazyFrames(object):
|
||||||
@@ -19,7 +19,7 @@ class LazyFrames(object):
|
|||||||
def __init__(self, frames, lz4_compress=False):
|
def __init__(self, frames, lz4_compress=False):
|
||||||
if lz4_compress:
|
if lz4_compress:
|
||||||
from lz4.block import compress
|
from lz4.block import compress
|
||||||
self.shape = frames[0].shape
|
self.frame_shape = frames[0].shape
|
||||||
self.dtype = frames[0].dtype
|
self.dtype = frames[0].dtype
|
||||||
frames = [compress(frame) for frame in frames]
|
frames = [compress(frame) for frame in frames]
|
||||||
self._frames = frames
|
self._frames = frames
|
||||||
@@ -28,7 +28,7 @@ class LazyFrames(object):
|
|||||||
def __array__(self, dtype=None):
|
def __array__(self, dtype=None):
|
||||||
if self.lz4_compress:
|
if self.lz4_compress:
|
||||||
from lz4.block import decompress
|
from lz4.block import decompress
|
||||||
frames = [np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.shape) for frame in self._frames]
|
frames = [np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape) for frame in self._frames]
|
||||||
else:
|
else:
|
||||||
frames = self._frames
|
frames = self._frames
|
||||||
out = np.stack(frames, axis=0)
|
out = np.stack(frames, axis=0)
|
||||||
@@ -42,8 +42,15 @@ class LazyFrames(object):
|
|||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
return self.__array__()[i]
|
return self.__array__()[i]
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.__array__() == other
|
||||||
|
|
||||||
class FrameStack(ObservationWrapper):
|
@property
|
||||||
|
def shape(self):
|
||||||
|
return self.__array__().shape
|
||||||
|
|
||||||
|
|
||||||
|
class FrameStack(Wrapper):
|
||||||
r"""Observation wrapper that stacks the observations in a rolling manner.
|
r"""Observation wrapper that stacks the observations in a rolling manner.
|
||||||
|
|
||||||
For example, if the number of stacks is 4, then the returned observation contains
|
For example, if the number of stacks is 4, then the returned observation contains
|
||||||
|
Reference in New Issue
Block a user