mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
Add padding_type
parameter to FrameSkipObservation
and change default to "reset" to preserve backwards comp (#830)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e7e80a9d62
commit
d756522eeb
@@ -299,26 +299,54 @@ class FrameStackObservation(
|
||||
is an array with shape [3], so if we stack 4 observations, the processed observation
|
||||
has shape [4, 3].
|
||||
|
||||
No vector version of the wrapper exists.
|
||||
Users have options for the padded observation used:
|
||||
|
||||
Note:
|
||||
- After :meth:`reset` is called, the frame buffer will be filled with the initial observation.
|
||||
I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames.
|
||||
* "reset" (default) - The reset value is repeated
|
||||
* "zero" - A "zero"-like instance of the observation space
|
||||
* custom - An instance of the observation space
|
||||
|
||||
No vector version of the wrapper exists.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import FrameStackObservation
|
||||
>>> env = gym.make("CarRacing-v2")
|
||||
>>> env = FrameStackObservation(env, 4)
|
||||
>>> env = FrameStackObservation(env, stack_size=4)
|
||||
>>> env.observation_space
|
||||
Box(0, 255, (4, 96, 96, 3), uint8)
|
||||
>>> obs, _ = env.reset()
|
||||
>>> obs.shape
|
||||
(4, 96, 96, 3)
|
||||
|
||||
Example with different padding observations:
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env.reset(seed=123)
|
||||
(array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {})
|
||||
>>> stacked_env = FrameStackObservation(env, 3) # the default is padding_type="reset"
|
||||
>>> stacked_env.reset(seed=123)
|
||||
(array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
|
||||
[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
|
||||
[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
|
||||
dtype=float32), {})
|
||||
|
||||
|
||||
>>> stacked_env = FrameStackObservation(env, 3, padding_type="zero")
|
||||
>>> stacked_env.reset(seed=123)
|
||||
(array([[ 0. , 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. , 0. ],
|
||||
[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
|
||||
dtype=float32), {})
|
||||
>>> stacked_env = FrameStackObservation(env, 3, padding_type=np.array([1, -1, 0, 2], dtype=np.float32))
|
||||
>>> stacked_env.reset(seed=123)
|
||||
(array([[ 1. , -1. , 0. , 2. ],
|
||||
[ 1. , -1. , 0. , 2. ],
|
||||
[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
|
||||
dtype=float32), {})
|
||||
|
||||
Change logs:
|
||||
* v0.15.0 - Initially add as ``FrameStack`` with support for lz4
|
||||
* v1.0.0 - Rename to ``FrameStackObservation`` and remove lz4 and ``LazyFrame`` support
|
||||
along with adding the ``padding_type`` parameter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -326,15 +354,20 @@ class FrameStackObservation(
|
||||
env: gym.Env[ObsType, ActType],
|
||||
stack_size: int,
|
||||
*,
|
||||
zeros_obs: ObsType | None = None,
|
||||
padding_type: str | ObsType = "reset",
|
||||
):
|
||||
"""Observation wrapper that stacks the observations in a rolling manner.
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
stack_size: The number of frames to stack with zero_obs being used originally.
|
||||
zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset`
|
||||
stack_size: The number of frames to stack.
|
||||
padding_type: The padding type to use when stacking the observations, options: "reset", "zero", custom obs
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, stack_size=stack_size, padding_type=padding_type
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
if not np.issubdtype(type(stack_size), np.integer):
|
||||
raise TypeError(
|
||||
f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
|
||||
@@ -343,22 +376,31 @@ class FrameStackObservation(
|
||||
raise ValueError(
|
||||
f"The stack_size needs to be greater than one, actual value: {stack_size}"
|
||||
)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
if isinstance(padding_type, str) and (
|
||||
padding_type == "reset" or padding_type == "zero"
|
||||
):
|
||||
self.padding_value: ObsType = create_zero_array(env.observation_space)
|
||||
elif padding_type in env.observation_space:
|
||||
self.padding_value = padding_type
|
||||
padding_type = "_custom"
|
||||
else:
|
||||
if isinstance(padding_type, str):
|
||||
raise ValueError( # we are guessing that the user just entered the "reset" or "zero" wrong
|
||||
f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r} not an instance of env observation ({env.observation_space})"
|
||||
)
|
||||
|
||||
self.observation_space = batch_space(env.observation_space, n=stack_size)
|
||||
self.stack_size: Final[int] = stack_size
|
||||
self.padding_type: Final[str] = padding_type
|
||||
|
||||
self.zero_obs: Final[ObsType] = (
|
||||
zeros_obs if zeros_obs else create_zero_array(env.observation_space)
|
||||
)
|
||||
self._stacked_obs = deque(
|
||||
[self.zero_obs for _ in range(self.stack_size)], maxlen=self.stack_size
|
||||
)
|
||||
self._stacked_array = create_empty_array(
|
||||
env.observation_space, n=self.stack_size
|
||||
self.obs_queue = deque(
|
||||
[self.padding_value for _ in range(self.stack_size)], maxlen=self.stack_size
|
||||
)
|
||||
self.stacked_obs = create_empty_array(env.observation_space, n=self.stack_size)
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
@@ -371,13 +413,11 @@ class FrameStackObservation(
|
||||
Returns:
|
||||
Stacked observations, reward, terminated, truncated, and info from the environment
|
||||
"""
|
||||
obs, reward, terminated, truncated, info = super().step(action)
|
||||
self._stacked_obs.append(obs)
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
self.obs_queue.append(obs)
|
||||
|
||||
updated_obs = deepcopy(
|
||||
concatenate(
|
||||
self.env.observation_space, self._stacked_obs, self._stacked_array
|
||||
)
|
||||
concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs)
|
||||
)
|
||||
return updated_obs, reward, terminated, truncated, info
|
||||
|
||||
@@ -393,15 +433,16 @@ class FrameStackObservation(
|
||||
Returns:
|
||||
The stacked observations and info
|
||||
"""
|
||||
obs, info = super().reset(seed=seed, options=options)
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
|
||||
if self.padding_type == "reset":
|
||||
self.padding_value = obs
|
||||
for _ in range(self.stack_size - 1):
|
||||
self._stacked_obs.append(self.zero_obs)
|
||||
self._stacked_obs.append(obs)
|
||||
self.obs_queue.append(self.padding_value)
|
||||
self.obs_queue.append(obs)
|
||||
|
||||
updated_obs = deepcopy(
|
||||
concatenate(
|
||||
self.env.observation_space, self._stacked_obs, self._stacked_array
|
||||
)
|
||||
concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs)
|
||||
)
|
||||
return updated_obs, info
|
||||
|
||||
|
@@ -1,32 +1,29 @@
|
||||
"""Test suite for FrameStackObservation wrapper."""
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
from gymnasium.vector.utils import iterate
|
||||
from gymnasium.wrappers import FrameStackObservation
|
||||
from gymnasium.wrappers.utils import create_zero_array
|
||||
from tests.wrappers.utils import SEED, TESTING_OBS_ENVS, TESTING_OBS_ENVS_IDS
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env", TESTING_OBS_ENVS, ids=TESTING_OBS_ENVS_IDS)
|
||||
def test_env_obs(env, stack_size: int = 3):
|
||||
"""Test different environment observations for testing."""
|
||||
def test_different_obs_spaces(env, stack_size: int = 3):
|
||||
"""Test across a large number of observation spaces to check if the FrameStack wrapper ."""
|
||||
obs, _ = env.reset(seed=SEED)
|
||||
env.action_space.seed(SEED)
|
||||
|
||||
unstacked_obs = [
|
||||
create_zero_array(env.observation_space) for _ in range(stack_size - 1)
|
||||
]
|
||||
unstacked_obs.append(obs)
|
||||
unstacked_obs = [obs for _ in range(stack_size)]
|
||||
for _ in range(stack_size * 2):
|
||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
unstacked_obs.append(obs)
|
||||
|
||||
env = FrameStackObservation(env, stack_size=stack_size)
|
||||
env.action_space.seed(SEED)
|
||||
env.action_space.seed(seed=SEED)
|
||||
|
||||
obs, _ = env.reset(seed=SEED)
|
||||
stacked_obs = [obs]
|
||||
@@ -50,25 +47,76 @@ def test_stack_size(stack_size: int):
|
||||
"""Test different stack sizes for FrameStackObservation wrapper."""
|
||||
env = gym.make("CartPole-v1")
|
||||
env.action_space.seed(seed=SEED)
|
||||
first_obs, _ = env.reset(seed=SEED)
|
||||
second_obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
|
||||
zero_obs = create_zero_array(env.observation_space)
|
||||
# Perform a series of actions and store the resulting observations
|
||||
unstacked_obs = []
|
||||
obs, _ = env.reset(seed=SEED)
|
||||
unstacked_obs.append(obs)
|
||||
first_obs = obs # Store the first observation
|
||||
for _ in range(5):
|
||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
unstacked_obs.append(obs)
|
||||
|
||||
env = FrameStackObservation(env, stack_size=stack_size)
|
||||
|
||||
env.action_space.seed(seed=SEED)
|
||||
obs, _ = env.reset(seed=SEED)
|
||||
unstacked_obs = list(iterate(env.observation_space, obs))
|
||||
assert len(unstacked_obs) == stack_size
|
||||
assert data_equivalence(
|
||||
[zero_obs for _ in range(stack_size - 1)], unstacked_obs[:-1]
|
||||
)
|
||||
assert data_equivalence(first_obs, unstacked_obs[-1])
|
||||
|
||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
unstacked_obs = list(iterate(env.observation_space, obs))
|
||||
assert data_equivalence(second_obs, unstacked_obs[-1])
|
||||
# Perform the same series of actions and store the resulting stacked observations
|
||||
stacked_obs = []
|
||||
obs, _ = env.reset(seed=SEED)
|
||||
stacked_obs.append(obs)
|
||||
for _ in range(5):
|
||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
stacked_obs.append(obs)
|
||||
|
||||
# Check that the frames in each stacked observation match the corresponding observations
|
||||
for i in range(len(stacked_obs)):
|
||||
frames = list(iterate(env.observation_space, stacked_obs[i]))
|
||||
for j in range(stack_size):
|
||||
if i - j < 0:
|
||||
# Use the first observation instead of a zero observation
|
||||
expected_obs = first_obs
|
||||
else:
|
||||
expected_obs = unstacked_obs[i - j]
|
||||
assert data_equivalence(expected_obs, frames[stack_size - 1 - j])
|
||||
|
||||
|
||||
def test_padding_type():
|
||||
env = gym.make("CartPole-v1")
|
||||
reset_obs, _ = env.reset(seed=123)
|
||||
action = env.action_space.sample()
|
||||
step_obs, _, _, _, _ = env.step(action)
|
||||
|
||||
stacked_env = FrameStackObservation(env, stack_size=3) # default = "reset"
|
||||
stacked_obs, _ = stacked_env.reset(seed=123)
|
||||
assert np.all(np.stack([reset_obs, reset_obs, reset_obs]) == stacked_obs)
|
||||
stacked_obs, _, _, _, _ = stacked_env.step(action)
|
||||
assert np.all(np.stack([reset_obs, reset_obs, step_obs]) == stacked_obs)
|
||||
|
||||
stacked_env = FrameStackObservation(env, stack_size=3, padding_type="zero")
|
||||
stacked_obs, _ = stacked_env.reset(seed=123)
|
||||
assert np.all(np.stack([np.zeros(4), np.zeros(4), reset_obs]) == stacked_obs)
|
||||
stacked_obs, _, _, _, _ = stacked_env.step(action)
|
||||
assert np.all(np.stack([np.zeros(4), reset_obs, step_obs]) == stacked_obs)
|
||||
|
||||
stacked_env = FrameStackObservation(
|
||||
env, stack_size=3, padding_type=np.array([1, -1, 0, 2], dtype=np.float32)
|
||||
)
|
||||
stacked_obs, _ = stacked_env.reset(seed=123)
|
||||
assert np.all(
|
||||
np.stack(
|
||||
[
|
||||
np.array([1, -1, 0, 2], dtype=np.float32),
|
||||
np.array([1, -1, 0, 2], dtype=np.float32),
|
||||
reset_obs,
|
||||
]
|
||||
)
|
||||
== stacked_obs
|
||||
)
|
||||
stacked_obs, _, _, _, _ = stacked_env.step(action)
|
||||
assert np.all(
|
||||
np.stack([np.array([1, -1, 0, 2], dtype=np.float32), reset_obs, step_obs])
|
||||
== stacked_obs
|
||||
)
|
||||
|
||||
|
||||
def test_stack_size_failures():
|
||||
@@ -85,6 +133,24 @@ def test_stack_size_failures():
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("The stack_size needs to be greater than one, actual value: 0"),
|
||||
match=re.escape("The stack_size needs to be greater than one, actual value: 1"),
|
||||
):
|
||||
FrameStackObservation(env, stack_size=0)
|
||||
FrameStackObservation(env, stack_size=1)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: 'unknown'"
|
||||
),
|
||||
):
|
||||
FrameStackObservation(env, stack_size=3, padding_type="unknown")
|
||||
|
||||
invalid_padding = np.array([1, 2, 3, 4, 5])
|
||||
assert invalid_padding not in env.observation_space
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: array([1, 2, 3, 4, 5])"
|
||||
),
|
||||
):
|
||||
FrameStackObservation(env, stack_size=3, padding_type=invalid_padding)
|
||||
|
Reference in New Issue
Block a user