mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 06:16:32 +00:00
Wrapper to discretize observations and actions (#1411)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
@@ -76,11 +76,13 @@ from gymnasium.wrappers.stateful_observation import (
|
||||
from gymnasium.wrappers.stateful_reward import NormalizeReward
|
||||
from gymnasium.wrappers.transform_action import (
|
||||
ClipAction,
|
||||
DiscretizeAction,
|
||||
RescaleAction,
|
||||
TransformAction,
|
||||
)
|
||||
from gymnasium.wrappers.transform_observation import (
|
||||
AddRenderObservation,
|
||||
DiscretizeObservation,
|
||||
DtypeObservation,
|
||||
FilterObservation,
|
||||
FlattenObservation,
|
||||
@@ -99,6 +101,7 @@ __all__ = [
|
||||
"AtariPreprocessing",
|
||||
"DelayObservation",
|
||||
"DtypeObservation",
|
||||
"DiscretizeObservation",
|
||||
"FilterObservation",
|
||||
"FlattenObservation",
|
||||
"FrameStackObservation",
|
||||
@@ -113,6 +116,7 @@ __all__ = [
|
||||
"TimeAwareObservation",
|
||||
# --- Action Wrappers ---
|
||||
"ClipAction",
|
||||
"DiscretizeAction",
|
||||
"TransformAction",
|
||||
"RescaleAction",
|
||||
# "NanAction",
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
* ``TransformAction`` - Transforms the actions based on a function
|
||||
* ``ClipAction`` - Clips the action within a bounds
|
||||
* ``DiscretizeAction`` - Discretizes a continuous Box action space into a single Discrete space
|
||||
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
|
||||
"""
|
||||
|
||||
@@ -13,7 +14,7 @@ import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType, WrapperActType
|
||||
from gymnasium.spaces import Box, Space
|
||||
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Space
|
||||
|
||||
|
||||
# Public API of this module; ``DiscretizeAction`` is defined below and exported like its siblings.
__all__ = ["TransformAction", "ClipAction", "RescaleAction", "DiscretizeAction"]
|
||||
@@ -178,3 +179,147 @@ class RescaleAction(
|
||||
func=func,
|
||||
action_space=act_space,
|
||||
)
|
||||
|
||||
|
||||
class DiscretizeAction(
    TransformAction[ObsType, WrapperActType, ActType],
    gym.utils.RecordConstructorArgs,
):
    """Uniformly discretizes a continuous Box action space into a single Discrete space.

    Every dimension ``i`` of the wrapped ``Box`` is split into ``bins[i]`` equal-width
    intervals, and a discrete action is translated into the centers of the selected
    intervals. By default the per-dimension bin indices are flattened into one
    ``Discrete`` index in row-major order (the last dimension varies fastest); with
    ``multidiscrete=True`` the action space is a ``MultiDiscrete`` holding one index
    per dimension.

    Example 1 - Discretize Pendulum action space:
        >>> env = gym.make("Pendulum-v1")
        >>> env.action_space
        Box(-2.0, 2.0, (1,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([-0.6])
        >>> obs
        array([-0.17606162, 0.9843792 , 0.5292768 ], dtype=float32)
        >>> env = DiscretizeAction(env, bins=10)
        >>> env.action_space
        Discrete(10)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step(3)
        >>> obs
        array([-0.17606162, 0.9843792 , 0.5292768 ], dtype=float32)

    Example 2 - Discretize Reacher action space:
        >>> env = gym.make("Reacher-v5")
        >>> env.action_space
        Box(-1.0, 1.0, (2,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([-0.3, -0.5])
        >>> obs
        array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
                0.11442572, -1.18958125, -1.97979484, 0.1054461 , -0.10896341])
        >>> env = DiscretizeAction(env, bins=10)
        >>> env.action_space
        Discrete(100)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step(32)
        >>> obs
        array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
                0.11442572, -1.18958118, -1.97979484, 0.1054461 , -0.10896341])

    Example 3 - Discretize Reacher action space with MultiDiscrete:
        >>> env = gym.make("Reacher-v5")
        >>> env.action_space
        Box(-1.0, 1.0, (2,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([-0.3, -0.5])
        >>> obs
        array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
                0.11442572, -1.18958125, -1.97979484, 0.1054461 , -0.10896341])
        >>> env = DiscretizeAction(env, bins=10, multidiscrete=True)
        >>> env.action_space
        MultiDiscrete([10 10])
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([3, 2])
        >>> obs
        array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
                0.11442572, -1.18958118, -1.97979484, 0.1054461 , -0.10896341])
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        bins: int | tuple[int, ...],
        multidiscrete: bool = False,
    ):
        """Constructor for the discretize action wrapper.

        Args:
            env: The environment to wrap.
            bins: int or tuple of ints (number of bins per dimension).
            multidiscrete: If True, use MultiDiscrete action space instead of flattening to Discrete.

        Raises:
            TypeError: If the action space of ``env`` is not a ``Box``.
            ValueError: If any bound of the action space is infinite.
        """
        if not isinstance(env.action_space, Box):
            raise TypeError(
                "DiscretizeAction is only compatible with Box continuous actions."
            )

        self.low = env.action_space.low
        self.high = env.action_space.high
        # Only the length of the first axis is used: the wrapper treats the Box as a flat vector.
        self.n_dims = self.low.shape[0]

        if np.any(np.isinf(self.low)) or np.any(np.isinf(self.high)):
            raise ValueError(
                "Discretization requires action space to be finite. "
                f"Found: low={self.low}, high={self.high}"
            )

        self.multidiscrete = multidiscrete
        # Record every constructor argument (including ``multidiscrete``) so the
        # recorded arguments describe this wrapper completely.
        gym.utils.RecordConstructorArgs.__init__(
            self, bins=bins, multidiscrete=multidiscrete
        )
        gym.ActionWrapper.__init__(self, env)

        # A single integer (Python or NumPy) means "the same number of bins in every dimension".
        if isinstance(bins, (int, np.integer)):
            self.bins = np.array([bins] * self.n_dims)
        else:
            assert (
                len(bins) == self.n_dims
            ), f"bins must match action dimensions: expected {self.n_dims}, got {len(bins)}"
            self.bins = np.array(bins)

        # For each dimension, the center of every bin is the midpoint of two consecutive edges.
        self.bin_centers = []
        for i in range(self.n_dims):
            edges = np.linspace(self.low[i], self.high[i], self.bins[i] + 1)
            self.bin_centers.append(0.5 * (edges[:-1] + edges[1:]))

        if self.multidiscrete:
            self.action_space = MultiDiscrete(self.bins)
        else:
            self.action_space = Discrete(np.prod(self.bins))

    def action(self, act):
        """Maps a discrete action onto the continuous action at the centers of the selected bins.

        Args:
            act: A flat ``Discrete`` index, or one bin index per dimension when ``multidiscrete`` is set.

        Returns:
            An array with one bin center per dimension, cast to the dtype of the wrapped action space.
        """
        if self.multidiscrete:
            indices = np.asarray(act, dtype=int)
        else:
            indices = self._unflatten_index(act)
        # Each per-dimension index is clamped into [0, bins[i] - 1] before the lookup.
        centers = [
            self.bin_centers[i][min(max(idx, 0), self.bins[i] - 1)]
            for i, idx in enumerate(indices)
        ]
        return np.array(centers, dtype=self.env.action_space.dtype)

    def revert_action(self, action):
        """Converts a continuous action into the discrete action whose bin centers are closest to it.

        Args:
            action: A continuous action with one value per dimension.

        Returns:
            An integer array of per-dimension bin indices when ``multidiscrete`` is set,
            otherwise the flattened (row-major) discrete index.
        """
        indices = [
            np.argmin(np.abs(self.bin_centers[i] - action[i]))
            for i in range(self.n_dims)
        ]
        if self.multidiscrete:
            return np.array(indices, dtype=int)
        else:
            return np.ravel_multi_index(indices, self.bins)

    def _unflatten_index(self, flat_index):
        """Splits a flat row-major index into a list with one bin index per dimension."""
        indices = []
        for b in reversed(self.bins):
            # ``divmod`` builds new objects, so a NumPy array handed in by the
            # caller is never modified in place (unlike ``flat_index //= b``).
            flat_index, remainder = divmod(flat_index, b)
            indices.append(remainder)
        return indices[::-1]
|
||||
|
@@ -9,6 +9,7 @@
|
||||
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
|
||||
* ``DtypeObservation`` - Convert an observation to a dtype
|
||||
* ``RenderObservation`` - Allows the observation to the rendered frame
|
||||
* ``DiscretizeObservation`` - Discretize a continuous Box observation space into a single Discrete space
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -34,6 +35,7 @@ __all__ = [
|
||||
"RescaleObservation",
|
||||
"DtypeObservation",
|
||||
"AddRenderObservation",
|
||||
"DiscretizeObservation",
|
||||
]
|
||||
|
||||
from gymnasium.wrappers.utils import rescale_box
|
||||
@@ -682,3 +684,155 @@ class AddRenderObservation(
|
||||
func=lambda obs: {obs_key: obs, render_key: self.render()},
|
||||
observation_space=obs_space,
|
||||
)
|
||||
|
||||
|
||||
class DiscretizeObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Uniformly discretizes a continuous Box observation space into a single Discrete space.

    Every dimension ``i`` of the wrapped ``Box`` is split into ``bins[i]`` equal-width
    intervals and an observation is replaced by the indices of the intervals it falls
    into. By default the per-dimension indices are flattened into one ``Discrete``
    index in row-major order (the last dimension varies fastest); with
    ``multidiscrete=True`` the observation space is a ``MultiDiscrete`` holding one
    index per dimension.

    Example 1 - Discretize MountainCar observation space:
        >>> env = gym.make("MountainCar-v0")
        >>> env.observation_space
        Box([-1.2 -0.07], [0.6 0.07], (2,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([-0.4452088, 0. ], dtype=float32)
        >>> env = DiscretizeObservation(env, bins=10)
        >>> env.observation_space
        Discrete(100)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        45

    Example 2 - Discretize LunarLander observation space:
        >>> env = gym.make("LunarLander-v3")
        >>> env.observation_space
        Box([ -2.5 -2.5 -10. -10. -6.2831855 -10.
         -0. -0. ], [ 2.5 2.5 10. 10. 6.2831855 10.
         1. 1. ], (8,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([ 0.00229702, 1.4181306 , 0.2326471 , 0.3204666 , -0.00265488,
               -0.05269805, 0. , 0. ], dtype=float32)
        >>> env = DiscretizeObservation(env, bins=3)
        >>> env.observation_space
        Discrete(6561)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        4005

    Example 3 - Discretize LunarLander observation space with MultiDiscrete:
        >>> env = gym.make("LunarLander-v3")
        >>> env.observation_space
        Box([ -2.5 -2.5 -10. -10. -6.2831855 -10.
         -0. -0. ], [ 2.5 2.5 10. 10. 6.2831855 10.
         1. 1. ], (8,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([ 0.00229702, 1.4181306 , 0.2326471 , 0.3204666 , -0.00265488,
               -0.05269805, 0. , 0. ], dtype=float32)
        >>> env = DiscretizeObservation(env, bins=3, multidiscrete=True)
        >>> env.observation_space
        MultiDiscrete([3 3 3 3 3 3 3 3])
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([1, 2, 1, 1, 1, 1, 0, 0])
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        bins: int | tuple[int, ...],
        multidiscrete: bool = False,
    ):
        """Constructor for the discretize observation wrapper.

        Args:
            env: The environment to wrap.
            bins: int or tuple of ints (number of bins per dimension).
            multidiscrete: If True, use MultiDiscrete space instead of flattening to Discrete.

        Raises:
            TypeError: If the observation space of ``env`` is not a ``Box``.
            ValueError: If any bound of the observation space is infinite.
        """
        if not isinstance(env.observation_space, spaces.Box):
            raise TypeError(
                "DiscretizeObservation is only compatible with Box continuous observations."
            )

        self.low = env.observation_space.low
        self.high = env.observation_space.high
        # Only the length of the first axis is used: the wrapper treats the Box as a flat vector.
        self.n_dims = self.low.shape[0]

        if np.any(np.isinf(self.low)) or np.any(np.isinf(self.high)):
            raise ValueError(
                "Discretization requires observation space to be finite. "
                f"Found: low={self.low}, high={self.high}"
            )

        self.multidiscrete = multidiscrete
        # Record every constructor argument (including ``multidiscrete``) so the
        # recorded arguments describe this wrapper completely.
        gym.utils.RecordConstructorArgs.__init__(
            self, bins=bins, multidiscrete=multidiscrete
        )
        gym.ObservationWrapper.__init__(self, env)

        # A single integer (Python or NumPy) means "the same number of bins in every dimension".
        if isinstance(bins, (int, np.integer)):
            self.bins = np.array([bins] * self.n_dims)
        else:
            assert (
                len(bins) == self.n_dims
            ), f"bins must match observation dimensions: expected {self.n_dims}, got {len(bins)}"
            self.bins = np.array(bins)

        # Keep only the interior edges: ``bins[i]`` bins are separated by ``bins[i] - 1`` edges.
        self.bin_edges = [
            np.linspace(self.low[i], self.high[i], self.bins[i] + 1)[1:-1]
            for i in range(self.n_dims)
        ]

        if self.multidiscrete:
            self.observation_space = spaces.MultiDiscrete(self.bins)
        else:
            self.observation_space = spaces.Discrete(np.prod(self.bins))

    def observation(self, observation):
        """Discretizes the observation.

        Args:
            observation: A continuous observation with one value per dimension.

        Returns:
            An ``int64`` array of per-dimension bin indices when ``multidiscrete`` is set,
            otherwise the flattened (row-major) index as a Python ``int``.
        """
        # ``bin_edges[i]`` holds only the interior edges, so ``np.digitize``
        # yields an index between 0 and ``bins[i] - 1``. The observation is
        # first clipped into the space bounds, with a small margin below
        # ``high`` to guard against precision issues at the upper bound.
        clipped = np.clip(observation, self.low, self.high - 1e-8)
        indices = [
            int(np.digitize(clipped[i], self.bin_edges[i])) for i in range(self.n_dims)
        ]
        if self.multidiscrete:
            return np.array(indices, dtype=np.int64)
        else:
            return int(self._flatten_indices(indices))

    def revert_observation(self, obs):
        """Reverts discretization. It returns the edges of the bin the discretized observation belongs to.

        Args:
            obs: A flat ``Discrete`` index, or one bin index per dimension when ``multidiscrete`` is set.

        Returns:
            A ``(lows, highs)`` tuple of arrays with the lower and upper edge of the
            bin in every dimension, cast to the dtype of the wrapped observation space.
        """
        if self.multidiscrete:
            indices = np.asarray(obs, dtype=int)
        else:
            indices = self._unflatten_index(obs)
        lows = []
        highs = []
        for i, idx in enumerate(indices):
            # Full edge list (both outer bounds included): bin ``idx`` spans edges[idx]..edges[idx + 1].
            edges = np.linspace(self.low[i], self.high[i], self.bins[i] + 1)
            lows.append(edges[idx])
            highs.append(edges[idx + 1])
        return np.array(lows, dtype=self.env.observation_space.dtype), np.array(
            highs, dtype=self.env.observation_space.dtype
        )

    def _flatten_indices(self, indices):
        """Combines per-dimension bin indices into a single row-major flat index."""
        flat_index = 0
        for i in range(self.n_dims):
            flat_index *= self.bins[i]
            flat_index += indices[i]
        return flat_index

    def _unflatten_index(self, flat_index):
        """Splits a flat row-major index into a list with one bin index per dimension."""
        indices = []
        for b in reversed(self.bins):
            # ``divmod`` builds new objects, so a NumPy array handed in by the
            # caller is never modified in place (unlike ``flat_index //= b``).
            flat_index, remainder = divmod(flat_index, b)
            indices.append(remainder)
        return indices[::-1]
|
||||
|
56
tests/wrappers/test_discretize_action.py
Normal file
56
tests/wrappers/test_discretize_action.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Test suite for DiscretizeAction wrapper."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.spaces import Box, Discrete
|
||||
from gymnasium.wrappers import DiscretizeAction
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
def test_discretize_action_space_uniformity(dimensions):
    """Checks that a uniform grid of continuous actions covers every discrete action exactly once."""
    n_bins = 7
    wrapped = DiscretizeAction(
        GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,))), n_bins
    )

    # One row per grid point: ``n_bins`` evenly spaced values along every dimension.
    axes = [np.linspace(0, 99, n_bins) for _ in range(dimensions)]
    grid_points = np.stack([axis.ravel() for axis in np.meshgrid(*axes)], axis=1)

    discrete_ids = np.sort([wrapped.revert_action(point) for point in grid_points])
    n_actions = wrapped.action_space.n

    assert grid_points.shape[0] == n_actions
    assert np.all(discrete_ids == np.arange(n_actions))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "dimensions, bins, multidiscrete",
    [
        (1, 3, False),
        (2, (3, 4), False),
        (3, (3, 4, 5), False),
        (1, 3, True),
        (2, (3, 4), True),
        (3, (3, 4, 5), True),
    ],
)
def test_revert_discretize_action_space(dimensions, bins, multidiscrete):
    """Checks that ``action`` followed by ``revert_action`` returns the sampled discrete action."""
    base_env = GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,)))
    wrapped = DiscretizeAction(base_env, bins, multidiscrete)

    for _ in range(1000):
        sampled = wrapped.action_space.sample()
        continuous = wrapped.action(sampled)
        # The continuous action must be valid for the wrapped environment...
        assert base_env.action_space.contains(continuous)
        # ...and must map back onto the discrete action it came from.
        recovered = wrapped.revert_action(continuous)
        assert np.all(recovered == sampled)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("low, high", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
def test_discretize_action_bounds(low, high):
    """Tests that the discretize action wrapper rejects action spaces with infinite bounds."""
    # Build the environment outside ``pytest.raises`` so that only an error
    # raised by the wrapper itself can satisfy the test.
    env = GenericTestEnv(action_space=Box(low, high, shape=(1,)))
    with pytest.raises(ValueError, match="requires action space to be finite"):
        DiscretizeAction(env, bins=10)
|
||||
|
||||
|
||||
def test_discretize_action_dtype():
    """Tests that the discretize action wrapper rejects action spaces that are not a Box."""
    env = GenericTestEnv(action_space=Discrete(10))
    # ``bins`` is supplied so the TypeError must come from the space type
    # check rather than from a missing constructor argument.
    with pytest.raises(TypeError, match="only compatible with Box"):
        DiscretizeAction(env, bins=10)
|
60
tests/wrappers/test_discretize_observation.py
Normal file
60
tests/wrappers/test_discretize_observation.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Test suite for DiscretizeObservation wrapper."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.spaces import Box, Discrete
|
||||
from gymnasium.wrappers import DiscretizeObservation
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
def test_discretize_observation_space_uniformity(dimensions):
    """Checks that a uniform grid of observations covers every discrete observation exactly once."""
    n_bins = 7
    wrapped = DiscretizeObservation(
        GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,))), n_bins
    )

    # One row per grid point: ``n_bins`` evenly spaced values along every dimension.
    axes = [np.linspace(0, 99, n_bins) for _ in range(dimensions)]
    grid_points = np.stack([axis.ravel() for axis in np.meshgrid(*axes)], axis=1)

    discrete_ids = np.sort([wrapped.observation(point) for point in grid_points])
    n_states = wrapped.observation_space.n

    assert grid_points.shape[0] == n_states
    assert np.all(discrete_ids == np.arange(n_states))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "dimensions, bins, multidiscrete",
    [
        (1, 3, False),
        (2, (3, 4), False),
        (3, (3, 4, 5), False),
        (1, 3, True),
        (2, (3, 4), True),
        (3, (3, 4, 5), True),
    ],
)
def test_revert_discretize_observation_space(dimensions, bins, multidiscrete):
    """Checks that each observation lies inside the bin edges returned by ``revert_observation``."""
    base_env = GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,)))
    wrapped = DiscretizeObservation(base_env, bins, multidiscrete)

    for seed in range(1000):
        # Reset the base and the wrapped environment with the same seed.
        raw_obs, _ = base_env.reset(seed=seed)
        discrete_obs, _ = wrapped.reset(seed=seed)
        lower, upper = wrapped.revert_observation(discrete_obs)
        assert np.all(raw_obs >= lower) and np.all(raw_obs <= upper)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("low, high", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
def test_discretize_observation_bounds(low, high):
    """Tests that the discretize observation wrapper rejects observation spaces with infinite bounds."""
    # Build the environment outside ``pytest.raises`` so that only an error
    # raised by the wrapper itself can satisfy the test.
    env = GenericTestEnv(observation_space=Box(low, high, shape=(1,)))
    with pytest.raises(ValueError, match="requires observation space to be finite"):
        DiscretizeObservation(env, bins=10)
|
||||
|
||||
|
||||
def test_discretize_observation_dtype():
    """Tests that the discretize observation wrapper rejects observation spaces that are not a Box."""
    env = GenericTestEnv(observation_space=Discrete(10))
    # ``bins`` is supplied so the TypeError must come from the space type
    # check rather than from a missing constructor argument.
    with pytest.raises(TypeError, match="only compatible with Box"):
        DiscretizeObservation(env, bins=10)
|
Reference in New Issue
Block a user