mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 15:04:20 +00:00
Wrapper to discretize observations and actions (#1411)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
@@ -76,11 +76,13 @@ from gymnasium.wrappers.stateful_observation import (
|
|||||||
from gymnasium.wrappers.stateful_reward import NormalizeReward
|
from gymnasium.wrappers.stateful_reward import NormalizeReward
|
||||||
from gymnasium.wrappers.transform_action import (
|
from gymnasium.wrappers.transform_action import (
|
||||||
ClipAction,
|
ClipAction,
|
||||||
|
DiscretizeAction,
|
||||||
RescaleAction,
|
RescaleAction,
|
||||||
TransformAction,
|
TransformAction,
|
||||||
)
|
)
|
||||||
from gymnasium.wrappers.transform_observation import (
|
from gymnasium.wrappers.transform_observation import (
|
||||||
AddRenderObservation,
|
AddRenderObservation,
|
||||||
|
DiscretizeObservation,
|
||||||
DtypeObservation,
|
DtypeObservation,
|
||||||
FilterObservation,
|
FilterObservation,
|
||||||
FlattenObservation,
|
FlattenObservation,
|
||||||
@@ -99,6 +101,7 @@ __all__ = [
|
|||||||
"AtariPreprocessing",
|
"AtariPreprocessing",
|
||||||
"DelayObservation",
|
"DelayObservation",
|
||||||
"DtypeObservation",
|
"DtypeObservation",
|
||||||
|
"DiscretizeObservation",
|
||||||
"FilterObservation",
|
"FilterObservation",
|
||||||
"FlattenObservation",
|
"FlattenObservation",
|
||||||
"FrameStackObservation",
|
"FrameStackObservation",
|
||||||
@@ -113,6 +116,7 @@ __all__ = [
|
|||||||
"TimeAwareObservation",
|
"TimeAwareObservation",
|
||||||
# --- Action Wrappers ---
|
# --- Action Wrappers ---
|
||||||
"ClipAction",
|
"ClipAction",
|
||||||
|
"DiscretizeAction",
|
||||||
"TransformAction",
|
"TransformAction",
|
||||||
"RescaleAction",
|
"RescaleAction",
|
||||||
# "NanAction",
|
# "NanAction",
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
* ``TransformAction`` - Transforms the actions based on a function
|
* ``TransformAction`` - Transforms the actions based on a function
|
||||||
* ``ClipAction`` - Clips the action within a bounds
|
* ``ClipAction`` - Clips the action within a bounds
|
||||||
|
* ``DiscretizeAction`` - Discretizes a continuous Box action space into a single Discrete space
|
||||||
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
|
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ import numpy as np
|
|||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.core import ActType, ObsType, WrapperActType
|
from gymnasium.core import ActType, ObsType, WrapperActType
|
||||||
from gymnasium.spaces import Box, Space
|
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Space
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["TransformAction", "ClipAction", "RescaleAction"]
|
__all__ = ["TransformAction", "ClipAction", "RescaleAction"]
|
||||||
@@ -178,3 +179,147 @@ class RescaleAction(
|
|||||||
func=func,
|
func=func,
|
||||||
action_space=act_space,
|
action_space=act_space,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscretizeAction(
|
||||||
|
TransformAction[ObsType, WrapperActType, ActType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
|
"""Uniformly discretizes a continuous Box action space into a single Discrete space.
|
||||||
|
|
||||||
|
Example 1 - Discretize Pendulum action space:
|
||||||
|
>>> env = gym.make("Pendulum-v1")
|
||||||
|
>>> env.action_space
|
||||||
|
Box(-2.0, 2.0, (1,), float32)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs, *_ = env.step([-0.6])
|
||||||
|
>>> obs
|
||||||
|
array([-0.17606162, 0.9843792 , 0.5292768 ], dtype=float32)
|
||||||
|
>>> env = DiscretizeAction(env, bins=10)
|
||||||
|
>>> env.action_space
|
||||||
|
Discrete(10)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs, *_ = env.step(3)
|
||||||
|
>>> obs
|
||||||
|
array([-0.17606162, 0.9843792 , 0.5292768 ], dtype=float32)
|
||||||
|
|
||||||
|
Example 2 - Discretize Reacher action space:
|
||||||
|
>>> env = gym.make("Reacher-v5")
|
||||||
|
>>> env.action_space
|
||||||
|
Box(-1.0, 1.0, (2,), float32)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs, *_ = env.step([-0.3, -0.5])
|
||||||
|
>>> obs
|
||||||
|
array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
|
||||||
|
0.11442572, -1.18958125, -1.97979484, 0.1054461 , -0.10896341])
|
||||||
|
>>> env = DiscretizeAction(env, bins=10)
|
||||||
|
>>> env.action_space
|
||||||
|
Discrete(100)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs, *_ = env.step(32)
|
||||||
|
>>> obs
|
||||||
|
array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
|
||||||
|
0.11442572, -1.18958118, -1.97979484, 0.1054461 , -0.10896341])
|
||||||
|
|
||||||
|
Example 2 - Discretize Reacher action space with MultiDiscrete:
|
||||||
|
>>> env = gym.make("Reacher-v5")
|
||||||
|
>>> env.action_space
|
||||||
|
Box(-1.0, 1.0, (2,), float32)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs, *_ = env.step([-0.3, -0.5])
|
||||||
|
>>> obs
|
||||||
|
array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
|
||||||
|
0.11442572, -1.18958125, -1.97979484, 0.1054461 , -0.10896341])
|
||||||
|
>>> env = DiscretizeAction(env, bins=10, multidiscrete=True)
|
||||||
|
>>> env.action_space
|
||||||
|
MultiDiscrete([10 10])
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs, *_ = env.step([3, 2])
|
||||||
|
>>> obs
|
||||||
|
array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
|
||||||
|
0.11442572, -1.18958118, -1.97979484, 0.1054461 , -0.10896341])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: gym.Env[ObsType, ActType],
|
||||||
|
bins: int | tuple[int, ...],
|
||||||
|
multidiscrete: bool = False,
|
||||||
|
):
|
||||||
|
"""Constructor for the discretize action wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap.
|
||||||
|
bins: int or tuple of ints (number of bins per dimension).
|
||||||
|
multidiscrete: If True, use MultiDiscrete action space instead of flattening to Discrete.
|
||||||
|
"""
|
||||||
|
if not isinstance(env.action_space, Box):
|
||||||
|
raise TypeError(
|
||||||
|
"DiscretizeAction is only compatible with Box continuous actions."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.low = env.action_space.low
|
||||||
|
self.high = env.action_space.high
|
||||||
|
self.n_dims = self.low.shape[0]
|
||||||
|
|
||||||
|
if np.any(np.isinf(self.low)) or np.any(np.isinf(self.high)):
|
||||||
|
raise ValueError(
|
||||||
|
"Discretization requires action space to be finite. "
|
||||||
|
f"Found: low={self.low}, high={self.high}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.multidiscrete = multidiscrete
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, bins=bins)
|
||||||
|
gym.ActionWrapper.__init__(self, env)
|
||||||
|
|
||||||
|
if isinstance(bins, int):
|
||||||
|
self.bins = np.array([bins] * self.n_dims)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
len(bins) == self.n_dims
|
||||||
|
), f"bins must match action dimensions: expected {self.n_dims}, got {len(bins)}"
|
||||||
|
self.bins = np.array(bins)
|
||||||
|
|
||||||
|
self.bin_centers = [
|
||||||
|
0.5
|
||||||
|
* (
|
||||||
|
np.linspace(self.low[i], self.high[i], self.bins[i] + 1)[:-1]
|
||||||
|
+ np.linspace(self.low[i], self.high[i], self.bins[i] + 1)[1:]
|
||||||
|
)
|
||||||
|
for i in range(self.n_dims)
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.multidiscrete:
|
||||||
|
self.action_space = MultiDiscrete(self.bins)
|
||||||
|
else:
|
||||||
|
self.action_space = Discrete(np.prod(self.bins))
|
||||||
|
|
||||||
|
def action(self, act):
|
||||||
|
"""Discretizes the action."""
|
||||||
|
if self.multidiscrete:
|
||||||
|
indices = np.asarray(act, dtype=int)
|
||||||
|
else:
|
||||||
|
indices = self._unflatten_index(act)
|
||||||
|
centers = [
|
||||||
|
self.bin_centers[i][min(max(idx, 0), self.bins[i] - 1)]
|
||||||
|
for i, idx in enumerate(indices)
|
||||||
|
]
|
||||||
|
return np.array(centers, dtype=self.env.action_space.dtype)
|
||||||
|
|
||||||
|
def revert_action(self, action):
|
||||||
|
"""Converts a discretized action to a possible continuous action (the center of the closest bin)."""
|
||||||
|
indices = [
|
||||||
|
np.argmin(np.abs(self.bin_centers[i] - action[i]))
|
||||||
|
for i in range(self.n_dims)
|
||||||
|
]
|
||||||
|
if self.multidiscrete:
|
||||||
|
return np.array(indices, dtype=int)
|
||||||
|
else:
|
||||||
|
return np.ravel_multi_index(indices, self.bins)
|
||||||
|
|
||||||
|
def _unflatten_index(self, flat_index):
|
||||||
|
indices = []
|
||||||
|
for b in reversed(self.bins):
|
||||||
|
indices.append(flat_index % b)
|
||||||
|
flat_index //= b
|
||||||
|
return list(reversed(indices))
|
||||||
|
@@ -9,6 +9,7 @@
|
|||||||
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
|
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
|
||||||
* ``DtypeObservation`` - Convert an observation to a dtype
|
* ``DtypeObservation`` - Convert an observation to a dtype
|
||||||
* ``RenderObservation`` - Allows the observation to the rendered frame
|
* ``RenderObservation`` - Allows the observation to the rendered frame
|
||||||
|
* ``DiscretizeObservation`` - Discretize a continuous Box observation space into a single Discrete space
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -34,6 +35,7 @@ __all__ = [
|
|||||||
"RescaleObservation",
|
"RescaleObservation",
|
||||||
"DtypeObservation",
|
"DtypeObservation",
|
||||||
"AddRenderObservation",
|
"AddRenderObservation",
|
||||||
|
"DiscretizeObservation",
|
||||||
]
|
]
|
||||||
|
|
||||||
from gymnasium.wrappers.utils import rescale_box
|
from gymnasium.wrappers.utils import rescale_box
|
||||||
@@ -682,3 +684,155 @@ class AddRenderObservation(
|
|||||||
func=lambda obs: {obs_key: obs, render_key: self.render()},
|
func=lambda obs: {obs_key: obs, render_key: self.render()},
|
||||||
observation_space=obs_space,
|
observation_space=obs_space,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscretizeObservation(
|
||||||
|
TransformObservation[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
|
"""Uniformly discretizes a continuous Box observation space into a single Discrete space.
|
||||||
|
|
||||||
|
Example 1 - Discretize MountainCar observation space:
|
||||||
|
>>> env = gym.make("MountainCar-v0")
|
||||||
|
>>> env.observation_space
|
||||||
|
Box([-1.2 -0.07], [0.6 0.07], (2,), float32)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs
|
||||||
|
array([-0.4452088, 0. ], dtype=float32)
|
||||||
|
>>> env = DiscretizeObservation(env, bins=10)
|
||||||
|
>>> env.observation_space
|
||||||
|
Discrete(100)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs
|
||||||
|
45
|
||||||
|
|
||||||
|
Example 2 - Discretize LunarLander observation space:
|
||||||
|
>>> env = gym.make("LunarLander-v3")
|
||||||
|
>>> env.observation_space
|
||||||
|
Box([ -2.5 -2.5 -10. -10. -6.2831855 -10.
|
||||||
|
-0. -0. ], [ 2.5 2.5 10. 10. 6.2831855 10.
|
||||||
|
1. 1. ], (8,), float32)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs
|
||||||
|
array([ 0.00229702, 1.4181306 , 0.2326471 , 0.3204666 , -0.00265488,
|
||||||
|
-0.05269805, 0. , 0. ], dtype=float32)
|
||||||
|
>>> env = DiscretizeObservation(env, bins=3)
|
||||||
|
>>> env.observation_space
|
||||||
|
Discrete(6561)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs
|
||||||
|
4005
|
||||||
|
|
||||||
|
Example 3 - Discretize LunarLander observation space with MultiDiscrete:
|
||||||
|
>>> env = gym.make("LunarLander-v3")
|
||||||
|
>>> env.observation_space
|
||||||
|
Box([ -2.5 -2.5 -10. -10. -6.2831855 -10.
|
||||||
|
-0. -0. ], [ 2.5 2.5 10. 10. 6.2831855 10.
|
||||||
|
1. 1. ], (8,), float32)
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs
|
||||||
|
array([ 0.00229702, 1.4181306 , 0.2326471 , 0.3204666 , -0.00265488,
|
||||||
|
-0.05269805, 0. , 0. ], dtype=float32)
|
||||||
|
>>> env = DiscretizeObservation(env, bins=3, multidiscrete=True)
|
||||||
|
>>> env.observation_space
|
||||||
|
MultiDiscrete([3 3 3 3 3 3 3 3])
|
||||||
|
>>> obs, _ = env.reset(seed=42)
|
||||||
|
>>> obs
|
||||||
|
array([1, 2, 1, 1, 1, 1, 0, 0])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: gym.Env[ObsType, ActType],
|
||||||
|
bins: int | tuple[int, ...],
|
||||||
|
multidiscrete: bool = False,
|
||||||
|
):
|
||||||
|
"""Constructor for the discretize observation wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap.
|
||||||
|
bins: int or tuple of ints (number of bins per dimension).
|
||||||
|
multidiscrete: If True, use MultiDiscrete space instead of flattening to Discrete.
|
||||||
|
"""
|
||||||
|
if not isinstance(env.observation_space, spaces.Box):
|
||||||
|
raise TypeError(
|
||||||
|
"DiscretizeObservation is only compatible with Box continuous observations."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.low = env.observation_space.low
|
||||||
|
self.high = env.observation_space.high
|
||||||
|
self.n_dims = self.low.shape[0]
|
||||||
|
|
||||||
|
if np.any(np.isinf(self.low)) or np.any(np.isinf(self.high)):
|
||||||
|
raise ValueError(
|
||||||
|
"Discretization requires observation space to be finite. "
|
||||||
|
f"Found: low={self.low}, high={self.high}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.multidiscrete = multidiscrete
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, bins=bins)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
|
if isinstance(bins, int):
|
||||||
|
self.bins = np.array([bins] * self.n_dims)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
len(bins) == self.n_dims
|
||||||
|
), f"bins must match action dimensions: expected {self.n_dims}, got {len(bins)}"
|
||||||
|
self.bins = np.array(bins)
|
||||||
|
|
||||||
|
self.bin_edges = [
|
||||||
|
np.linspace(self.low[i], self.high[i], self.bins[i] + 1)[1:-1]
|
||||||
|
for i in range(self.n_dims)
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.multidiscrete:
|
||||||
|
self.observation_space = spaces.MultiDiscrete(self.bins)
|
||||||
|
else:
|
||||||
|
self.observation_space = spaces.Discrete(np.prod(self.bins))
|
||||||
|
|
||||||
|
def observation(self, observation):
|
||||||
|
"""Discretizes the observation."""
|
||||||
|
# np.digitize returns len(bins) if the input exceeds the last edge.
|
||||||
|
# If an observation is exactly equal to the high bound, the resulting
|
||||||
|
# index could be out of range for the number of bins.
|
||||||
|
# Solution: clip to ensure 0 <= index < bins[i], and add a small margin
|
||||||
|
# to prevent precision issues.
|
||||||
|
clipped = np.clip(observation, self.low, self.high - 1e-8)
|
||||||
|
indices = [
|
||||||
|
int(np.digitize(clipped[i], self.bin_edges[i])) for i in range(self.n_dims)
|
||||||
|
]
|
||||||
|
if self.multidiscrete:
|
||||||
|
return np.array(indices, dtype=np.int64)
|
||||||
|
else:
|
||||||
|
return int(self._flatten_indices(indices))
|
||||||
|
|
||||||
|
def revert_observation(self, obs):
|
||||||
|
"""Reverts discretization. It returns the edges of the bin the discretized observation belongs to."""
|
||||||
|
if self.multidiscrete:
|
||||||
|
indices = np.asarray(obs, dtype=int)
|
||||||
|
else:
|
||||||
|
indices = self._unflatten_index(obs)
|
||||||
|
lows = []
|
||||||
|
highs = []
|
||||||
|
for i, idx in enumerate(indices):
|
||||||
|
edges = np.linspace(self.low[i], self.high[i], self.bins[i] + 1)
|
||||||
|
lows.append(edges[idx])
|
||||||
|
highs.append(edges[idx + 1])
|
||||||
|
return np.array(lows, dtype=self.env.observation_space.dtype), np.array(
|
||||||
|
highs, dtype=self.env.observation_space.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def _flatten_indices(self, indices):
|
||||||
|
flat_index = 0
|
||||||
|
for i in range(self.n_dims):
|
||||||
|
flat_index *= self.bins[i]
|
||||||
|
flat_index += indices[i]
|
||||||
|
return flat_index
|
||||||
|
|
||||||
|
def _unflatten_index(self, flat_index):
|
||||||
|
indices = []
|
||||||
|
for b in reversed(self.bins):
|
||||||
|
indices.insert(0, flat_index % b)
|
||||||
|
flat_index //= b
|
||||||
|
return indices
|
||||||
|
56
tests/wrappers/test_discretize_action.py
Normal file
56
tests/wrappers/test_discretize_action.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Test suite for DiscretizeAction wrapper."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gymnasium.spaces import Box, Discrete
|
||||||
|
from gymnasium.wrappers import DiscretizeAction
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
|
||||||
|
def test_discretize_action_space_uniformity(dimensions):
|
||||||
|
"""Tests that the Box action space is discretized uniformly."""
|
||||||
|
env = GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,)))
|
||||||
|
n_bins = 7
|
||||||
|
env = DiscretizeAction(env, n_bins)
|
||||||
|
env_act = np.meshgrid(*(np.linspace(0, 99, n_bins) for _ in range(dimensions)))
|
||||||
|
env_act = np.concatenate([o.flatten()[None] for o in env_act], 0).T
|
||||||
|
env_act_discretized = np.sort([env.revert_action(a) for a in env_act])
|
||||||
|
assert env_act.shape[0] == env.action_space.n
|
||||||
|
assert np.all(env_act_discretized == np.arange(env.action_space.n))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dimensions, bins, multidiscrete",
|
||||||
|
[
|
||||||
|
(1, 3, False),
|
||||||
|
(2, (3, 4), False),
|
||||||
|
(3, (3, 4, 5), False),
|
||||||
|
(1, 3, True),
|
||||||
|
(2, (3, 4), True),
|
||||||
|
(3, (3, 4, 5), True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_revert_discretize_action_space(dimensions, bins, multidiscrete):
|
||||||
|
"""Tests that the action is discretized correctly within the bins."""
|
||||||
|
env = GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,)))
|
||||||
|
env_discrete = DiscretizeAction(env, bins, multidiscrete)
|
||||||
|
for i in range(1000):
|
||||||
|
act_discrete = env_discrete.action_space.sample()
|
||||||
|
act_continuous = env_discrete.action(act_discrete)
|
||||||
|
assert env.action_space.contains(act_continuous)
|
||||||
|
assert np.all(env_discrete.revert_action(act_continuous) == act_discrete)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("high, low", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
|
||||||
|
def test_discretize_action_bounds(high, low):
|
||||||
|
"""Tests the discretize action wrapper with spaces that should raise an error."""
|
||||||
|
with pytest.raises((ValueError,)):
|
||||||
|
DiscretizeAction(GenericTestEnv(action_space=Box(low, high, shape=(1,))))
|
||||||
|
|
||||||
|
|
||||||
|
def test_discretize_action_dtype():
|
||||||
|
"""Tests the discretize action wrapper with spaces that should raise an error."""
|
||||||
|
with pytest.raises((TypeError,)):
|
||||||
|
DiscretizeAction(GenericTestEnv(action_space=Discrete(10)))
|
60
tests/wrappers/test_discretize_observation.py
Normal file
60
tests/wrappers/test_discretize_observation.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""Test suite for DiscretizeObservation wrapper."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gymnasium.spaces import Box, Discrete
|
||||||
|
from gymnasium.wrappers import DiscretizeObservation
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
|
||||||
|
def test_discretize_observation_space_uniformity(dimensions):
|
||||||
|
"""Tests that the Box observation space is discretized uniformly."""
|
||||||
|
env = GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,)))
|
||||||
|
n_bins = 7
|
||||||
|
env = DiscretizeObservation(env, n_bins)
|
||||||
|
env_obs = np.meshgrid(*(np.linspace(0, 99, n_bins) for _ in range(dimensions)))
|
||||||
|
env_obs = np.concatenate([o.flatten()[None] for o in env_obs], 0).T
|
||||||
|
env_obs_discretized = np.sort([env.observation(e) for e in env_obs])
|
||||||
|
assert env_obs.shape[0] == env.observation_space.n
|
||||||
|
assert np.all(env_obs_discretized == np.arange(env.observation_space.n))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dimensions, bins, multidiscrete",
|
||||||
|
[
|
||||||
|
(1, 3, False),
|
||||||
|
(2, (3, 4), False),
|
||||||
|
(3, (3, 4, 5), False),
|
||||||
|
(1, 3, True),
|
||||||
|
(2, (3, 4), True),
|
||||||
|
(3, (3, 4, 5), True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_revert_discretize_observation_space(dimensions, bins, multidiscrete):
|
||||||
|
"""Tests that the observation is discretized correctly within the bins."""
|
||||||
|
env = GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,)))
|
||||||
|
env_discrete = DiscretizeObservation(env, bins, multidiscrete)
|
||||||
|
for i in range(1000):
|
||||||
|
obs, _ = env.reset(seed=i)
|
||||||
|
obs_discrete, _ = env_discrete.reset(seed=i)
|
||||||
|
obs_reverted_min, obs_reverted_max = env_discrete.revert_observation(
|
||||||
|
obs_discrete,
|
||||||
|
)
|
||||||
|
assert np.all(obs >= obs_reverted_min) and np.all(obs <= obs_reverted_max)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("high, low", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
|
||||||
|
def test_discretize_observation_bounds(high, low):
|
||||||
|
"""Tests the discretize observation wrapper with spaces that should raise an error."""
|
||||||
|
with pytest.raises((ValueError,)):
|
||||||
|
DiscretizeObservation(
|
||||||
|
GenericTestEnv(observation_space=Box(low, high, shape=(1,)))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_discretize_observation_dtype():
|
||||||
|
"""Tests the discretize observation wrapper with spaces that should raise an error."""
|
||||||
|
with pytest.raises((TypeError,)):
|
||||||
|
DiscretizeObservation(GenericTestEnv(observation_space=Discrete(10)))
|
Reference in New Issue
Block a user