Add padding_type parameter to FrameSkipObservation and change default to "reset" to preserve backwards comp (#830)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
Jose Antonio Martin H
2023-12-19 07:53:14 +01:00
committed by GitHub
parent e7e80a9d62
commit d756522eeb
2 changed files with 161 additions and 54 deletions

View File

@@ -299,26 +299,54 @@ class FrameStackObservation(
is an array with shape [3], so if we stack 4 observations, the processed observation
has shape [4, 3].
No vector version of the wrapper exists.
Users have options for the padded observation used:
Note:
- After :meth:`reset` is called, the frame buffer will be filled with the initial observation.
I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames.
* "reset" (default) - The reset value is repeated
* "zero" - A "zero"-like instance of the observation space
* custom - An instance of the observation space
No vector version of the wrapper exists.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import FrameStackObservation
>>> env = gym.make("CarRacing-v2")
>>> env = FrameStackObservation(env, 4)
>>> env = FrameStackObservation(env, stack_size=4)
>>> env.observation_space
Box(0, 255, (4, 96, 96, 3), uint8)
>>> obs, _ = env.reset()
>>> obs.shape
(4, 96, 96, 3)
Example with different padding observations:
>>> env = gym.make("CartPole-v1")
>>> env.reset(seed=123)
(array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {})
>>> stacked_env = FrameStackObservation(env, 3) # the default is padding_type="reset"
>>> stacked_env.reset(seed=123)
(array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
dtype=float32), {})
>>> stacked_env = FrameStackObservation(env, 3, padding_type="zero")
>>> stacked_env.reset(seed=123)
(array([[ 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ],
[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
dtype=float32), {})
>>> stacked_env = FrameStackObservation(env, 3, padding_type=np.array([1, -1, 0, 2], dtype=np.float32))
>>> stacked_env.reset(seed=123)
(array([[ 1. , -1. , 0. , 2. ],
[ 1. , -1. , 0. , 2. ],
[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
dtype=float32), {})
Change logs:
* v0.15.0 - Initially add as ``FrameStack`` with support for lz4
* v1.0.0 - Rename to ``FrameStackObservation`` and remove lz4 and ``LazyFrame`` support
along with adding the ``padding_type`` parameter
"""
def __init__(
@@ -326,15 +354,20 @@ class FrameStackObservation(
env: gym.Env[ObsType, ActType],
stack_size: int,
*,
zeros_obs: ObsType | None = None,
padding_type: str | ObsType = "reset",
):
"""Observation wrapper that stacks the observations in a rolling manner.
Args:
env: The environment to apply the wrapper
stack_size: The number of frames to stack with zero_obs being used originally.
zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset`
stack_size: The number of frames to stack.
padding_type: The padding type to use when stacking the observations, options: "reset", "zero", custom obs
"""
gym.utils.RecordConstructorArgs.__init__(
self, stack_size=stack_size, padding_type=padding_type
)
gym.Wrapper.__init__(self, env)
if not np.issubdtype(type(stack_size), np.integer):
raise TypeError(
f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
@@ -343,22 +376,31 @@ class FrameStackObservation(
raise ValueError(
f"The stack_size needs to be greater than one, actual value: {stack_size}"
)
gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
gym.Wrapper.__init__(self, env)
if isinstance(padding_type, str) and (
padding_type == "reset" or padding_type == "zero"
):
self.padding_value: ObsType = create_zero_array(env.observation_space)
elif padding_type in env.observation_space:
self.padding_value = padding_type
padding_type = "_custom"
else:
if isinstance(padding_type, str):
raise ValueError( # we are guessing that the user just entered the "reset" or "zero" wrong
f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r}"
)
else:
raise ValueError(
f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r} not an instance of env observation ({env.observation_space})"
)
self.observation_space = batch_space(env.observation_space, n=stack_size)
self.stack_size: Final[int] = stack_size
self.padding_type: Final[str] = padding_type
self.zero_obs: Final[ObsType] = (
zeros_obs if zeros_obs else create_zero_array(env.observation_space)
)
self._stacked_obs = deque(
[self.zero_obs for _ in range(self.stack_size)], maxlen=self.stack_size
)
self._stacked_array = create_empty_array(
env.observation_space, n=self.stack_size
self.obs_queue = deque(
[self.padding_value for _ in range(self.stack_size)], maxlen=self.stack_size
)
self.stacked_obs = create_empty_array(env.observation_space, n=self.stack_size)
def step(
self, action: WrapperActType
@@ -371,13 +413,11 @@ class FrameStackObservation(
Returns:
Stacked observations, reward, terminated, truncated, and info from the environment
"""
obs, reward, terminated, truncated, info = super().step(action)
self._stacked_obs.append(obs)
obs, reward, terminated, truncated, info = self.env.step(action)
self.obs_queue.append(obs)
updated_obs = deepcopy(
concatenate(
self.env.observation_space, self._stacked_obs, self._stacked_array
)
concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs)
)
return updated_obs, reward, terminated, truncated, info
@@ -393,15 +433,16 @@ class FrameStackObservation(
Returns:
The stacked observations and info
"""
obs, info = super().reset(seed=seed, options=options)
obs, info = self.env.reset(seed=seed, options=options)
if self.padding_type == "reset":
self.padding_value = obs
for _ in range(self.stack_size - 1):
self._stacked_obs.append(self.zero_obs)
self._stacked_obs.append(obs)
self.obs_queue.append(self.padding_value)
self.obs_queue.append(obs)
updated_obs = deepcopy(
concatenate(
self.env.observation_space, self._stacked_obs, self._stacked_array
)
concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs)
)
return updated_obs, info

View File

@@ -1,32 +1,29 @@
"""Test suite for FrameStackObservation wrapper."""
import re
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector.utils import iterate
from gymnasium.wrappers import FrameStackObservation
from gymnasium.wrappers.utils import create_zero_array
from tests.wrappers.utils import SEED, TESTING_OBS_ENVS, TESTING_OBS_ENVS_IDS
@pytest.mark.parametrize("env", TESTING_OBS_ENVS, ids=TESTING_OBS_ENVS_IDS)
def test_env_obs(env, stack_size: int = 3):
"""Test different environment observations for testing."""
def test_different_obs_spaces(env, stack_size: int = 3):
"""Test across a large number of observation spaces to check if the FrameStack wrapper ."""
obs, _ = env.reset(seed=SEED)
env.action_space.seed(SEED)
unstacked_obs = [
create_zero_array(env.observation_space) for _ in range(stack_size - 1)
]
unstacked_obs.append(obs)
unstacked_obs = [obs for _ in range(stack_size)]
for _ in range(stack_size * 2):
obs, _, _, _, _ = env.step(env.action_space.sample())
unstacked_obs.append(obs)
env = FrameStackObservation(env, stack_size=stack_size)
env.action_space.seed(SEED)
env.action_space.seed(seed=SEED)
obs, _ = env.reset(seed=SEED)
stacked_obs = [obs]
@@ -50,25 +47,76 @@ def test_stack_size(stack_size: int):
"""Test different stack sizes for FrameStackObservation wrapper."""
env = gym.make("CartPole-v1")
env.action_space.seed(seed=SEED)
first_obs, _ = env.reset(seed=SEED)
second_obs, _, _, _, _ = env.step(env.action_space.sample())
zero_obs = create_zero_array(env.observation_space)
# Perform a series of actions and store the resulting observations
unstacked_obs = []
obs, _ = env.reset(seed=SEED)
unstacked_obs.append(obs)
first_obs = obs # Store the first observation
for _ in range(5):
obs, _, _, _, _ = env.step(env.action_space.sample())
unstacked_obs.append(obs)
env = FrameStackObservation(env, stack_size=stack_size)
env.action_space.seed(seed=SEED)
obs, _ = env.reset(seed=SEED)
unstacked_obs = list(iterate(env.observation_space, obs))
assert len(unstacked_obs) == stack_size
assert data_equivalence(
[zero_obs for _ in range(stack_size - 1)], unstacked_obs[:-1]
)
assert data_equivalence(first_obs, unstacked_obs[-1])
obs, _, _, _, _ = env.step(env.action_space.sample())
unstacked_obs = list(iterate(env.observation_space, obs))
assert data_equivalence(second_obs, unstacked_obs[-1])
# Perform the same series of actions and store the resulting stacked observations
stacked_obs = []
obs, _ = env.reset(seed=SEED)
stacked_obs.append(obs)
for _ in range(5):
obs, _, _, _, _ = env.step(env.action_space.sample())
stacked_obs.append(obs)
# Check that the frames in each stacked observation match the corresponding observations
for i in range(len(stacked_obs)):
frames = list(iterate(env.observation_space, stacked_obs[i]))
for j in range(stack_size):
if i - j < 0:
# Use the first observation instead of a zero observation
expected_obs = first_obs
else:
expected_obs = unstacked_obs[i - j]
assert data_equivalence(expected_obs, frames[stack_size - 1 - j])
def test_padding_type():
env = gym.make("CartPole-v1")
reset_obs, _ = env.reset(seed=123)
action = env.action_space.sample()
step_obs, _, _, _, _ = env.step(action)
stacked_env = FrameStackObservation(env, stack_size=3) # default = "reset"
stacked_obs, _ = stacked_env.reset(seed=123)
assert np.all(np.stack([reset_obs, reset_obs, reset_obs]) == stacked_obs)
stacked_obs, _, _, _, _ = stacked_env.step(action)
assert np.all(np.stack([reset_obs, reset_obs, step_obs]) == stacked_obs)
stacked_env = FrameStackObservation(env, stack_size=3, padding_type="zero")
stacked_obs, _ = stacked_env.reset(seed=123)
assert np.all(np.stack([np.zeros(4), np.zeros(4), reset_obs]) == stacked_obs)
stacked_obs, _, _, _, _ = stacked_env.step(action)
assert np.all(np.stack([np.zeros(4), reset_obs, step_obs]) == stacked_obs)
stacked_env = FrameStackObservation(
env, stack_size=3, padding_type=np.array([1, -1, 0, 2], dtype=np.float32)
)
stacked_obs, _ = stacked_env.reset(seed=123)
assert np.all(
np.stack(
[
np.array([1, -1, 0, 2], dtype=np.float32),
np.array([1, -1, 0, 2], dtype=np.float32),
reset_obs,
]
)
== stacked_obs
)
stacked_obs, _, _, _, _ = stacked_env.step(action)
assert np.all(
np.stack([np.array([1, -1, 0, 2], dtype=np.float32), reset_obs, step_obs])
== stacked_obs
)
def test_stack_size_failures():
@@ -85,6 +133,24 @@ def test_stack_size_failures():
with pytest.raises(
ValueError,
match=re.escape("The stack_size needs to be greater than one, actual value: 0"),
match=re.escape("The stack_size needs to be greater than one, actual value: 1"),
):
FrameStackObservation(env, stack_size=0)
FrameStackObservation(env, stack_size=1)
with pytest.raises(
ValueError,
match=re.escape(
"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: 'unknown'"
),
):
FrameStackObservation(env, stack_size=3, padding_type="unknown")
invalid_padding = np.array([1, 2, 3, 4, 5])
assert invalid_padding not in env.observation_space
with pytest.raises(
ValueError,
match=re.escape(
"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: array([1, 2, 3, 4, 5])"
),
):
FrameStackObservation(env, stack_size=3, padding_type=invalid_padding)