redo black

This commit is contained in:
Justin Terry
2021-07-29 12:42:48 -04:00
parent d5004b7ec1
commit e9d2c41f2b
109 changed files with 459 additions and 1363 deletions

View File

@@ -22,7 +22,9 @@ class LazyFrames(object):
__slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")
def __init__(self, frames, lz4_compress=False):
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
self.frame_shape = tuple(frames[0].shape)
self.shape = (len(frames),) + self.frame_shape
self.dtype = frames[0].dtype
@@ -45,9 +47,7 @@ class LazyFrames(object):
def __getitem__(self, int_or_slice):
if isinstance(int_or_slice, int):
return self._check_decompress(self._frames[int_or_slice]) # single frame
return np.stack(
[self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
)
return np.stack([self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0)
def __eq__(self, other):
return self.__array__() == other
@@ -56,9 +56,7 @@ class LazyFrames(object):
if self.lz4_compress:
from lz4.block import decompress
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
self.frame_shape
)
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape)
return frame
@@ -102,12 +100,8 @@ class FrameStack(Wrapper):
self.frames = deque(maxlen=num_stack)
low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
high = np.repeat(
self.observation_space.high[np.newaxis, ...], num_stack, axis=0
)
self.observation_space = Box(
low=low, high=high, dtype=self.observation_space.dtype
)
high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0)
self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)
def _get_observation(self):
assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)