import warnings

import numpy as np

from gym import ObservationWrapper
from gym.spaces import Box


class ResizeObservation(ObservationWrapper):
    r"""Downsample the image observation to a target ``(height, width)``.

    A single int ``n`` is shorthand for the square target ``(n, n)``; any
    ``(height, width)`` pair is accepted. Dimensions of the wrapped
    observation space beyond the first two (e.g. a channel axis) are kept
    as-is, and every returned observation has exactly the shape declared by
    ``self.observation_space``.
    """

    def __init__(self, env, shape):
        """Wrap ``env`` so its observations are resized to ``shape``.

        Args:
            env: environment to wrap. Its observation space's first two
                dimensions are treated as (height, width) and replaced by
                ``shape`` (assumes an image-like Box layout — confirm for
                non-image envs).
            shape: int ``n`` (meaning ``(n, n)``) or an iterable of two
                positive values ``(height, width)``.

        Raises:
            AssertionError: if ``shape`` does not contain exactly two
                positive values.
        """
        super(ResizeObservation, self).__init__(env)
        if isinstance(shape, int):
            shape = (shape, shape)
        # Materialize first: validating a one-shot iterable with a generator
        # would exhaust it and leave an empty tuple behind.
        shape = tuple(shape)
        # cv2.resize needs exactly (height, width), so reject other lengths
        # here instead of failing obscurely inside OpenCV later.
        assert len(shape) == 2 and all(x > 0 for x in shape), shape
        warnings.warn(
            "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit",
            stacklevel=2,
        )
        self.shape = shape
        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        """Resize one observation to ``self.observation_space.shape``.

        Uses area interpolation (``cv2.INTER_AREA``).
        """
        # Imported here rather than at module top, presumably so OpenCV is
        # only required when this wrapper is actually used.
        import cv2

        # cv2.resize takes dsize as (width, height), hence the reversal.
        observation = cv2.resize(
            observation, self.shape[::-1], interpolation=cv2.INTER_AREA
        )
        # cv2.resize drops a trailing singleton channel axis, and the wrapped
        # space may have had no channel axis at all. Reshaping to the declared
        # space restores (h, w, 1) when the space has it and keeps (h, w)
        # when it does not. Unconditionally expanding 2-D results would
        # disagree with a 2-D observation space.
        return observation.reshape(self.observation_space.shape)