2019-06-08 01:01:35 +02:00
import numpy as np
2021-07-28 22:21:47 -04:00
import warnings
2019-06-08 01:01:35 +02:00
from gym . spaces import Box
from gym import ObservationWrapper
class ResizeObservation(ObservationWrapper):
    r"""Resize image observations to a fixed ``(height, width)``.

    The first two dimensions of every observation are resized with OpenCV's
    area interpolation; any further dimensions of the wrapped environment's
    observation space (e.g. a channel axis) are kept as-is.

    Args:
        env: The environment to wrap. Its ``observation_space.shape`` is
            assumed to be ``(height, width, ...)`` -- TODO confirm for
            non-image spaces.
        shape: Target size: either an ``int`` (square ``shape x shape``
            output) or a ``(height, width)`` pair of positive integers.
    """

    def __init__(self, env, shape):
        super(ResizeObservation, self).__init__(env)
        if isinstance(shape, int):
            shape = (shape, shape)
        # cv2.resize only accepts a 2-element target size, so reject shapes
        # such as (h, w, c) here instead of failing on the first observation.
        assert len(shape) == 2 and all(x > 0 for x in shape), shape

        warnings.warn(
            "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
        )

        self.shape = tuple(shape)

        # Keep every dimension after (height, width), e.g. the channel axis.
        obs_shape = self.shape + self.observation_space.shape[2:]
        # NOTE(review): the new space is always uint8 in [0, 255]; this
        # presumes the wrapped env emits uint8 images -- confirm for float envs.
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        """Resize ``observation`` so it matches ``self.observation_space.shape``.

        Args:
            observation: Array whose first two axes are ``(height, width)``.

        Returns:
            The resized array, shaped exactly like ``self.observation_space``.
        """
        # Imported lazily so the wrapper module can be imported without OpenCV.
        import cv2

        # cv2.resize takes dsize as (width, height), hence the reversal.
        observation = cv2.resize(
            observation, self.shape[::-1], interpolation=cv2.INTER_AREA
        )
        # cv2.resize drops a trailing size-1 channel axis, giving (h, w) for an
        # (H, W, 1) input; reshaping to the declared space restores it. Unlike an
        # unconditional expand_dims, this leaves 2-D inputs 2-D, so the output
        # always agrees with observation_space.
        return observation.reshape(self.observation_space.shape)