Fix return_info for wrappers (#2612)

* Fix `return_info` for `ObservationWrapper`, `AtariPreprocessing`, and `FrameStack`

* Make type checkers and IDEs happier

* Merge in #2454

* Update the info dict based on no-op steps
Add some type hints

* Bug fix

* Handle resets during frameskip
This commit is contained in:
Ariel Kwiatkowski
2022-02-17 18:03:35 +01:00
committed by GitHub
parent 8dc1b52d21
commit 3fa10a2360
3 changed files with 50 additions and 24 deletions

View File

@@ -300,8 +300,11 @@ class Wrapper(Env[ObsType, ActType]):
class ObservationWrapper(Wrapper): class ObservationWrapper(Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
observation = self.env.reset(**kwargs) if kwargs.get("return_info", False):
return self.observation(observation) obs, info = self.env.reset(**kwargs)
return self.observation(obs), info
else:
return self.observation(self.env.reset(**kwargs))
def step(self, action): def step(self, action):
observation, reward, done, info = self.env.step(action) observation, reward, done, info = self.env.step(action)

View File

@@ -1,9 +1,6 @@
from typing import Optional
import numpy as np import numpy as np
import gym import gym
from gym.spaces import Box from gym.spaces import Box
from gym.wrappers import TimeLimit
try: try:
import cv2 import cv2
@@ -45,14 +42,14 @@ class AtariPreprocessing(gym.Wrapper):
def __init__( def __init__(
self, self,
env, env: gym.Env,
noop_max=30, noop_max: int = 30,
frame_skip=4, frame_skip: int = 4,
screen_size=84, screen_size: int = 84,
terminal_on_life_loss=False, terminal_on_life_loss: bool = False,
grayscale_obs=True, grayscale_obs: bool = True,
grayscale_newaxis=False, grayscale_newaxis: bool = False,
scale_obs=False, scale_obs: bool = False,
): ):
super().__init__(env) super().__init__(env)
assert ( assert (
@@ -62,10 +59,14 @@ class AtariPreprocessing(gym.Wrapper):
assert screen_size > 0 assert screen_size > 0
assert noop_max >= 0 assert noop_max >= 0
if frame_skip > 1: if frame_skip > 1:
assert "NoFrameskip" in env.spec.id, ( if (
"disable frame-skipping in the original env. for more than one" "NoFrameskip" not in env.spec.id
" frame-skip as it will be done by the wrapper" and getattr(env.unwrapped, "_frameskip", None) != 1
) ):
raise ValueError(
"Disable frame-skipping in the original env. Otherwise, more than one"
" frame-skip will happen as through this wrapper"
)
self.noop_max = noop_max self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP" assert env.unwrapped.get_action_meanings()[0] == "NOOP"
@@ -131,16 +132,26 @@ class AtariPreprocessing(gym.Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
# NoopReset # NoopReset
self.env.reset(**kwargs) if kwargs.get("return_info", False):
_, reset_info = self.env.reset(**kwargs)
else:
_ = self.env.reset(**kwargs)
reset_info = {}
noops = ( noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1) self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0 if self.noop_max > 0
else 0 else 0
) )
for _ in range(noops): for _ in range(noops):
_, _, done, _ = self.env.step(0) _, _, done, step_info = self.env.step(0)
reset_info.update(step_info)
if done: if done:
self.env.reset(**kwargs) if kwargs.get("return_info", False):
_, reset_info = self.env.reset(**kwargs)
else:
_ = self.env.reset(**kwargs)
reset_info = {}
self.lives = self.ale.lives() self.lives = self.ale.lives()
if self.grayscale_obs: if self.grayscale_obs:
@@ -148,7 +159,11 @@ class AtariPreprocessing(gym.Wrapper):
else: else:
self.ale.getScreenRGB(self.obs_buffer[0]) self.ale.getScreenRGB(self.obs_buffer[0])
self.obs_buffer[1].fill(0) self.obs_buffer[1].fill(0)
return self._get_obs()
if kwargs.get("return_info", False):
return self._get_obs(), reset_info
else:
return self._get_obs()
def _get_obs(self): def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling if self.frame_skip > 1: # more efficient in-place pooling

View File

@@ -119,6 +119,14 @@ class FrameStack(ObservationWrapper):
return self.observation(), reward, done, info return self.observation(), reward, done, info
def reset(self, **kwargs): def reset(self, **kwargs):
observation = self.env.reset(**kwargs) if kwargs.get("return_info", False):
[self.frames.append(observation) for _ in range(self.num_stack)] obs, info = self.env.reset(**kwargs)
return self.observation() else:
obs = self.env.reset(**kwargs)
info = None # Unused
[self.frames.append(obs) for _ in range(self.num_stack)]
if kwargs.get("return_info", False):
return self.observation(), info
else:
return self.observation()