support color vs. grayscale option in WarpFrame wrapper (#166)
* support color vs. grayscale option in WarpFrame wrapper * Support color in other wrappers * Updated per Peters suggestions
This commit is contained in:
committed by
Peter Zhokhov
parent
e619e42364
commit
63151af41a
@@ -129,18 +129,26 @@ class ClipRewardEnv(gym.RewardWrapper):
|
|||||||
return np.sign(reward)
|
return np.sign(reward)
|
||||||
|
|
||||||
class WarpFrame(gym.ObservationWrapper):
|
class WarpFrame(gym.ObservationWrapper):
|
||||||
def __init__(self, env, width=84, height=84):
|
def __init__(self, env, width=84, height=84, grayscale=True):
|
||||||
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
||||||
gym.ObservationWrapper.__init__(self, env)
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
self.observation_space = spaces.Box(low=0, high=255,
|
self.grayscale = grayscale
|
||||||
shape=(self.height, self.width, 1), dtype=np.uint8)
|
if self.grayscale:
|
||||||
|
self.observation_space = spaces.Box(low=0, high=255,
|
||||||
|
shape=(self.height, self.width, 1), dtype=np.uint8)
|
||||||
|
else:
|
||||||
|
self.observation_space = spaces.Box(low=0, high=255,
|
||||||
|
shape=(self.height, self.width, 3), dtype=np.uint8)
|
||||||
|
|
||||||
def observation(self, frame):
|
def observation(self, frame):
|
||||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
if self.grayscale:
|
||||||
|
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
||||||
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
||||||
return frame[:, :, None]
|
if self.grayscale:
|
||||||
|
frame = np.expand_dims(frame, -1)
|
||||||
|
return frame
|
||||||
|
|
||||||
class FrameStack(gym.Wrapper):
|
class FrameStack(gym.Wrapper):
|
||||||
def __init__(self, env, k):
|
def __init__(self, env, k):
|
||||||
@@ -156,7 +164,7 @@ class FrameStack(gym.Wrapper):
|
|||||||
self.k = k
|
self.k = k
|
||||||
self.frames = deque([], maxlen=k)
|
self.frames = deque([], maxlen=k)
|
||||||
shp = env.observation_space.shape
|
shp = env.observation_space.shape
|
||||||
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
|
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
ob = self.env.reset()
|
ob = self.env.reset()
|
||||||
@@ -197,7 +205,7 @@ class LazyFrames(object):
|
|||||||
|
|
||||||
def _force(self):
|
def _force(self):
|
||||||
if self._out is None:
|
if self._out is None:
|
||||||
self._out = np.concatenate(self._frames, axis=2)
|
self._out = np.concatenate(self._frames, axis=-1)
|
||||||
self._frames = None
|
self._frames = None
|
||||||
return self._out
|
return self._out
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user