2019-08-23 23:47:07 +02:00
import numpy as np
2021-07-28 22:21:47 -04:00
import warnings
2019-08-23 23:47:07 +02:00
from gym . spaces import Box
from gym import ObservationWrapper
class GrayScaleObservation(ObservationWrapper):
    """Convert the image observation from RGB to gray scale.

    The wrapped environment's observation space must have shape ``(H, W, 3)``
    (the last axis is treated as RGB channels). Each observation is converted
    with OpenCV's ``COLOR_RGB2GRAY`` conversion.

    Args:
        env: Environment to wrap.
        keep_dim: If ``True``, a trailing singleton axis is kept so that
            observations (and the new observation space) have shape
            ``(H, W, 1)``; otherwise they have shape ``(H, W)``.
    """

    def __init__(self, env, keep_dim=False):
        super().__init__(env)
        self.keep_dim = keep_dim

        # Only (H, W, 3) RGB-shaped observation spaces are supported.
        env_shape = env.observation_space.shape
        assert len(env_shape) == 3 and env_shape[-1] == 3, (
            "GrayScaleObservation expects an observation space of shape "
            "(H, W, 3), got {}".format(env_shape)
        )

        # stacklevel=2 attributes the warning to the code constructing the
        # wrapper instead of to this line inside the library.
        warnings.warn(
            "Gym's internal preprocessing wrappers are now deprecated. "
            "While they will continue to work for the foreseeable future, "
            "we strongly recommend using SuperSuit instead: "
            "https://github.com/PettingZoo-Team/SuperSuit",
            stacklevel=2,
        )

        # Replace the RGB space with a single-channel uint8 space; the
        # channel axis is kept (as size 1) only when keep_dim is set.
        obs_shape = self.observation_space.shape[:2]
        if self.keep_dim:
            obs_shape = (obs_shape[0], obs_shape[1], 1)
        self.observation_space = Box(
            low=0, high=255, shape=obs_shape, dtype=np.uint8
        )

    def observation(self, observation):
        """Return ``observation`` converted from RGB to gray scale.

        Args:
            observation: RGB image as produced by the wrapped environment
                (assumed to be an ``(H, W, 3)`` uint8 array, matching the
                space checked in ``__init__`` — NOTE(review): the dtype is
                not checked here).

        Returns:
            The gray-scale image, shape ``(H, W, 1)`` if ``keep_dim`` is set,
            otherwise ``(H, W)``.
        """
        # Imported here rather than at module level, so OpenCV is only needed
        # once an observation is actually converted.
        import cv2

        observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            # cvtColor drops the channel axis; restore it as a size-1 axis.
            observation = np.expand_dims(observation, -1)
        return observation