mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
Fix `return_info` for wrappers (#2612)
* Fix `return_info` for Observation wrappers, Atari (?) and framestack
* Make type checkers and IDEs happier
* Merge in #2454
* Update the info dict based on no-op steps
* Some type hints
* Bug fix
* Handle resets during frameskip
This commit is contained in:
committed by
GitHub
parent
8dc1b52d21
commit
3fa10a2360
@@ -300,8 +300,11 @@ class Wrapper(Env[ObsType, ActType]):
|
||||
|
||||
class ObservationWrapper(Wrapper):
|
||||
def reset(self, **kwargs):
|
||||
observation = self.env.reset(**kwargs)
|
||||
return self.observation(observation)
|
||||
if kwargs.get("return_info", False):
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
return self.observation(obs), info
|
||||
else:
|
||||
return self.observation(self.env.reset(**kwargs))
|
||||
|
||||
def step(self, action):
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
|
@@ -1,9 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
from gym.spaces import Box
|
||||
from gym.wrappers import TimeLimit
|
||||
|
||||
try:
|
||||
import cv2
|
||||
@@ -45,14 +42,14 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env,
|
||||
noop_max=30,
|
||||
frame_skip=4,
|
||||
screen_size=84,
|
||||
terminal_on_life_loss=False,
|
||||
grayscale_obs=True,
|
||||
grayscale_newaxis=False,
|
||||
scale_obs=False,
|
||||
env: gym.Env,
|
||||
noop_max: int = 30,
|
||||
frame_skip: int = 4,
|
||||
screen_size: int = 84,
|
||||
terminal_on_life_loss: bool = False,
|
||||
grayscale_obs: bool = True,
|
||||
grayscale_newaxis: bool = False,
|
||||
scale_obs: bool = False,
|
||||
):
|
||||
super().__init__(env)
|
||||
assert (
|
||||
@@ -62,10 +59,14 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
assert screen_size > 0
|
||||
assert noop_max >= 0
|
||||
if frame_skip > 1:
|
||||
assert "NoFrameskip" in env.spec.id, (
|
||||
"disable frame-skipping in the original env. for more than one"
|
||||
" frame-skip as it will be done by the wrapper"
|
||||
)
|
||||
if (
|
||||
"NoFrameskip" not in env.spec.id
|
||||
and getattr(env.unwrapped, "_frameskip", None) != 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Disable frame-skipping in the original env. Otherwise, more than one"
|
||||
" frame-skip will happen as through this wrapper"
|
||||
)
|
||||
self.noop_max = noop_max
|
||||
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
|
||||
|
||||
@@ -131,16 +132,26 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
|
||||
def reset(self, **kwargs):
|
||||
# NoopReset
|
||||
self.env.reset(**kwargs)
|
||||
if kwargs.get("return_info", False):
|
||||
_, reset_info = self.env.reset(**kwargs)
|
||||
else:
|
||||
_ = self.env.reset(**kwargs)
|
||||
reset_info = {}
|
||||
|
||||
noops = (
|
||||
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
|
||||
if self.noop_max > 0
|
||||
else 0
|
||||
)
|
||||
for _ in range(noops):
|
||||
_, _, done, _ = self.env.step(0)
|
||||
_, _, done, step_info = self.env.step(0)
|
||||
reset_info.update(step_info)
|
||||
if done:
|
||||
self.env.reset(**kwargs)
|
||||
if kwargs.get("return_info", False):
|
||||
_, reset_info = self.env.reset(**kwargs)
|
||||
else:
|
||||
_ = self.env.reset(**kwargs)
|
||||
reset_info = {}
|
||||
|
||||
self.lives = self.ale.lives()
|
||||
if self.grayscale_obs:
|
||||
@@ -148,7 +159,11 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
else:
|
||||
self.ale.getScreenRGB(self.obs_buffer[0])
|
||||
self.obs_buffer[1].fill(0)
|
||||
return self._get_obs()
|
||||
|
||||
if kwargs.get("return_info", False):
|
||||
return self._get_obs(), reset_info
|
||||
else:
|
||||
return self._get_obs()
|
||||
|
||||
def _get_obs(self):
|
||||
if self.frame_skip > 1: # more efficient in-place pooling
|
||||
|
@@ -119,6 +119,14 @@ class FrameStack(ObservationWrapper):
|
||||
return self.observation(), reward, done, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
observation = self.env.reset(**kwargs)
|
||||
[self.frames.append(observation) for _ in range(self.num_stack)]
|
||||
return self.observation()
|
||||
if kwargs.get("return_info", False):
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
else:
|
||||
obs = self.env.reset(**kwargs)
|
||||
info = None # Unused
|
||||
[self.frames.append(obs) for _ in range(self.num_stack)]
|
||||
|
||||
if kwargs.get("return_info", False):
|
||||
return self.observation(), info
|
||||
else:
|
||||
return self.observation()
|
||||
|
Reference in New Issue
Block a user