mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-16 19:49:13 +00:00
Fix `return_info` for wrappers (#2612)
* Fix `return_info` for Observation wrappers, Atari (?) and FrameStack
* Make type checkers and IDEs happier
* Merge in #2454
* Update the info dict based on no-op steps
  Some type hints
* Bug fix
* Handle resets during frameskip
This commit is contained in:
committed by
GitHub
parent
8dc1b52d21
commit
3fa10a2360
@@ -300,8 +300,11 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
|
|
||||||
class ObservationWrapper(Wrapper):
|
class ObservationWrapper(Wrapper):
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
observation = self.env.reset(**kwargs)
|
if kwargs.get("return_info", False):
|
||||||
return self.observation(observation)
|
obs, info = self.env.reset(**kwargs)
|
||||||
|
return self.observation(obs), info
|
||||||
|
else:
|
||||||
|
return self.observation(self.env.reset(**kwargs))
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
observation, reward, done, info = self.env.step(action)
|
observation, reward, done, info = self.env.step(action)
|
||||||
|
@@ -1,9 +1,6 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
from gym.spaces import Box
|
from gym.spaces import Box
|
||||||
from gym.wrappers import TimeLimit
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
@@ -45,14 +42,14 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env,
|
env: gym.Env,
|
||||||
noop_max=30,
|
noop_max: int = 30,
|
||||||
frame_skip=4,
|
frame_skip: int = 4,
|
||||||
screen_size=84,
|
screen_size: int = 84,
|
||||||
terminal_on_life_loss=False,
|
terminal_on_life_loss: bool = False,
|
||||||
grayscale_obs=True,
|
grayscale_obs: bool = True,
|
||||||
grayscale_newaxis=False,
|
grayscale_newaxis: bool = False,
|
||||||
scale_obs=False,
|
scale_obs: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
assert (
|
assert (
|
||||||
@@ -62,10 +59,14 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
assert screen_size > 0
|
assert screen_size > 0
|
||||||
assert noop_max >= 0
|
assert noop_max >= 0
|
||||||
if frame_skip > 1:
|
if frame_skip > 1:
|
||||||
assert "NoFrameskip" in env.spec.id, (
|
if (
|
||||||
"disable frame-skipping in the original env. for more than one"
|
"NoFrameskip" not in env.spec.id
|
||||||
" frame-skip as it will be done by the wrapper"
|
and getattr(env.unwrapped, "_frameskip", None) != 1
|
||||||
)
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Disable frame-skipping in the original env. Otherwise, more than one"
|
||||||
|
" frame-skip will happen as through this wrapper"
|
||||||
|
)
|
||||||
self.noop_max = noop_max
|
self.noop_max = noop_max
|
||||||
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
|
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
|
||||||
|
|
||||||
@@ -131,16 +132,26 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
# NoopReset
|
# NoopReset
|
||||||
self.env.reset(**kwargs)
|
if kwargs.get("return_info", False):
|
||||||
|
_, reset_info = self.env.reset(**kwargs)
|
||||||
|
else:
|
||||||
|
_ = self.env.reset(**kwargs)
|
||||||
|
reset_info = {}
|
||||||
|
|
||||||
noops = (
|
noops = (
|
||||||
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
|
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
|
||||||
if self.noop_max > 0
|
if self.noop_max > 0
|
||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
for _ in range(noops):
|
for _ in range(noops):
|
||||||
_, _, done, _ = self.env.step(0)
|
_, _, done, step_info = self.env.step(0)
|
||||||
|
reset_info.update(step_info)
|
||||||
if done:
|
if done:
|
||||||
self.env.reset(**kwargs)
|
if kwargs.get("return_info", False):
|
||||||
|
_, reset_info = self.env.reset(**kwargs)
|
||||||
|
else:
|
||||||
|
_ = self.env.reset(**kwargs)
|
||||||
|
reset_info = {}
|
||||||
|
|
||||||
self.lives = self.ale.lives()
|
self.lives = self.ale.lives()
|
||||||
if self.grayscale_obs:
|
if self.grayscale_obs:
|
||||||
@@ -148,7 +159,11 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
else:
|
else:
|
||||||
self.ale.getScreenRGB(self.obs_buffer[0])
|
self.ale.getScreenRGB(self.obs_buffer[0])
|
||||||
self.obs_buffer[1].fill(0)
|
self.obs_buffer[1].fill(0)
|
||||||
return self._get_obs()
|
|
||||||
|
if kwargs.get("return_info", False):
|
||||||
|
return self._get_obs(), reset_info
|
||||||
|
else:
|
||||||
|
return self._get_obs()
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
if self.frame_skip > 1: # more efficient in-place pooling
|
if self.frame_skip > 1: # more efficient in-place pooling
|
||||||
|
@@ -119,6 +119,14 @@ class FrameStack(ObservationWrapper):
|
|||||||
return self.observation(), reward, done, info
|
return self.observation(), reward, done, info
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
observation = self.env.reset(**kwargs)
|
if kwargs.get("return_info", False):
|
||||||
[self.frames.append(observation) for _ in range(self.num_stack)]
|
obs, info = self.env.reset(**kwargs)
|
||||||
return self.observation()
|
else:
|
||||||
|
obs = self.env.reset(**kwargs)
|
||||||
|
info = None # Unused
|
||||||
|
[self.frames.append(obs) for _ in range(self.num_stack)]
|
||||||
|
|
||||||
|
if kwargs.get("return_info", False):
|
||||||
|
return self.observation(), info
|
||||||
|
else:
|
||||||
|
return self.observation()
|
||||||
|
Reference in New Issue
Block a user