Fix return_info for wrappers (#2612)

* Fix `return_info` for observation wrappers, Atari preprocessing, and frame stacking

* Make type checkers and IDEs happier

* Merge in #2454

* Update the info dict based on no-op steps, and add some type hints

* Bug fix

* Handle resets during frameskip
This commit is contained in:
Ariel Kwiatkowski
2022-02-17 18:03:35 +01:00
committed by GitHub
parent 8dc1b52d21
commit 3fa10a2360
3 changed files with 50 additions and 24 deletions

View File

@@ -300,8 +300,11 @@ class Wrapper(Env[ObsType, ActType]):
class ObservationWrapper(Wrapper):
def reset(self, **kwargs):
    """Reset the wrapped env and map the observation through ``self.observation``.

    Supports the ``return_info`` reset kwarg: when truthy, the underlying
    env's reset returns ``(obs, info)`` and the info dict is passed through
    unchanged alongside the transformed observation.
    """
    if kwargs.get("return_info", False):
        obs, info = self.env.reset(**kwargs)
        return self.observation(obs), info
    else:
        # Legacy path: env.reset returns only the observation.
        return self.observation(self.env.reset(**kwargs))
def step(self, action):
observation, reward, done, info = self.env.step(action)

View File

@@ -1,9 +1,6 @@
from typing import Optional
import numpy as np
import gym
from gym.spaces import Box
from gym.wrappers import TimeLimit
try:
import cv2
@@ -45,14 +42,14 @@ class AtariPreprocessing(gym.Wrapper):
def __init__(
self,
env,
noop_max=30,
frame_skip=4,
screen_size=84,
terminal_on_life_loss=False,
grayscale_obs=True,
grayscale_newaxis=False,
scale_obs=False,
env: gym.Env,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
terminal_on_life_loss: bool = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
scale_obs: bool = False,
):
super().__init__(env)
assert (
@@ -62,10 +59,14 @@ class AtariPreprocessing(gym.Wrapper):
assert screen_size > 0
assert noop_max >= 0
if frame_skip > 1:
assert "NoFrameskip" in env.spec.id, (
"disable frame-skipping in the original env. for more than one"
" frame-skip as it will be done by the wrapper"
)
if (
"NoFrameskip" not in env.spec.id
and getattr(env.unwrapped, "_frameskip", None) != 1
):
raise ValueError(
"Disable frame-skipping in the original env. Otherwise, more than one"
" frame-skip will happen as through this wrapper"
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
@@ -131,16 +132,26 @@ class AtariPreprocessing(gym.Wrapper):
def reset(self, **kwargs):
# NoopReset
self.env.reset(**kwargs)
if kwargs.get("return_info", False):
_, reset_info = self.env.reset(**kwargs)
else:
_ = self.env.reset(**kwargs)
reset_info = {}
noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
for _ in range(noops):
_, _, done, _ = self.env.step(0)
_, _, done, step_info = self.env.step(0)
reset_info.update(step_info)
if done:
self.env.reset(**kwargs)
if kwargs.get("return_info", False):
_, reset_info = self.env.reset(**kwargs)
else:
_ = self.env.reset(**kwargs)
reset_info = {}
self.lives = self.ale.lives()
if self.grayscale_obs:
@@ -148,7 +159,11 @@ class AtariPreprocessing(gym.Wrapper):
else:
self.ale.getScreenRGB(self.obs_buffer[0])
self.obs_buffer[1].fill(0)
return self._get_obs()
if kwargs.get("return_info", False):
return self._get_obs(), reset_info
else:
return self._get_obs()
def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling

View File

@@ -119,6 +119,14 @@ class FrameStack(ObservationWrapper):
return self.observation(), reward, done, info
def reset(self, **kwargs):
    """Reset the env and pre-fill the frame buffer with the first observation.

    Honors the ``return_info`` reset kwarg: when truthy, forwards the info
    dict from the underlying env alongside the stacked observation.
    """
    return_info = kwargs.get("return_info", False)
    if return_info:
        obs, info = self.env.reset(**kwargs)
    else:
        obs = self.env.reset(**kwargs)
        info = None  # unused in this branch
    # Seed the stack: the frame deque starts out filled with num_stack
    # copies of the initial observation. (Plain loop — a comprehension
    # used only for its side effect is an anti-pattern.)
    for _ in range(self.num_stack):
        self.frames.append(obs)
    if return_info:
        return self.observation(), info
    else:
        return self.observation()