Wrapper to discretize observations and actions (#1411)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
Simone Parisi
2025-07-03 13:05:55 -06:00
committed by GitHub
parent 7f11576fe6
commit 18da906de7
5 changed files with 420 additions and 1 deletions

View File

@@ -76,11 +76,13 @@ from gymnasium.wrappers.stateful_observation import (
from gymnasium.wrappers.stateful_reward import NormalizeReward from gymnasium.wrappers.stateful_reward import NormalizeReward
from gymnasium.wrappers.transform_action import ( from gymnasium.wrappers.transform_action import (
ClipAction, ClipAction,
DiscretizeAction,
RescaleAction, RescaleAction,
TransformAction, TransformAction,
) )
from gymnasium.wrappers.transform_observation import ( from gymnasium.wrappers.transform_observation import (
AddRenderObservation, AddRenderObservation,
DiscretizeObservation,
DtypeObservation, DtypeObservation,
FilterObservation, FilterObservation,
FlattenObservation, FlattenObservation,
@@ -99,6 +101,7 @@ __all__ = [
"AtariPreprocessing", "AtariPreprocessing",
"DelayObservation", "DelayObservation",
"DtypeObservation", "DtypeObservation",
"DiscretizeObservation",
"FilterObservation", "FilterObservation",
"FlattenObservation", "FlattenObservation",
"FrameStackObservation", "FrameStackObservation",
@@ -113,6 +116,7 @@ __all__ = [
"TimeAwareObservation", "TimeAwareObservation",
# --- Action Wrappers --- # --- Action Wrappers ---
"ClipAction", "ClipAction",
"DiscretizeAction",
"TransformAction", "TransformAction",
"RescaleAction", "RescaleAction",
# "NanAction", # "NanAction",

View File

@@ -2,6 +2,7 @@
* ``TransformAction`` - Transforms the actions based on a function * ``TransformAction`` - Transforms the actions based on a function
* ``ClipAction`` - Clips the action within a bounds * ``ClipAction`` - Clips the action within a bounds
* ``DiscretizeAction`` - Discretizes a continuous Box action space into a single Discrete space
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions * ``RescaleAction`` - Rescales the action within a minimum and maximum actions
""" """
@@ -13,7 +14,7 @@ import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium.core import ActType, ObsType, WrapperActType from gymnasium.core import ActType, ObsType, WrapperActType
from gymnasium.spaces import Box, Space from gymnasium.spaces import Box, Discrete, MultiDiscrete, Space
__all__ = ["TransformAction", "ClipAction", "RescaleAction"] __all__ = ["TransformAction", "ClipAction", "RescaleAction"]
@@ -178,3 +179,147 @@ class RescaleAction(
func=func, func=func,
action_space=act_space, action_space=act_space,
) )
class DiscretizeAction(
    TransformAction[ObsType, WrapperActType, ActType],
    gym.utils.RecordConstructorArgs,
):
    """Uniformly discretizes a continuous Box action space into a single Discrete space.

    Every dimension of the wrapped ``Box`` action space is split into ``bins`` intervals of
    equal width. A discrete action selects one interval per dimension and is converted to the
    continuous action made of the centers of the selected intervals.

    The number of dimensions is read from ``low.shape[0]``, i.e. the wrapper is written for
    one-dimensional ``Box`` spaces.

    Example 1 - Discretize Pendulum action space:
        >>> env = gym.make("Pendulum-v1")
        >>> env.action_space
        Box(-2.0, 2.0, (1,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([-0.6])
        >>> obs
        array([-0.17606162,  0.9843792 ,  0.5292768 ], dtype=float32)
        >>> env = DiscretizeAction(env, bins=10)
        >>> env.action_space
        Discrete(10)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step(3)
        >>> obs
        array([-0.17606162,  0.9843792 ,  0.5292768 ], dtype=float32)

    Example 2 - Discretize Reacher action space:
        >>> env = gym.make("Reacher-v5")
        >>> env.action_space
        Box(-1.0, 1.0, (2,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([-0.3, -0.5])
        >>> obs
        array([ 0.99908342,  0.99948506,  0.04280567, -0.03208766,  0.10445588,
                0.11442572, -1.18958125, -1.97979484,  0.1054461 , -0.10896341])
        >>> env = DiscretizeAction(env, bins=10)
        >>> env.action_space
        Discrete(100)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step(32)
        >>> obs
        array([ 0.99908342,  0.99948506,  0.04280567, -0.03208766,  0.10445588,
                0.11442572, -1.18958118, -1.97979484,  0.1054461 , -0.10896341])

    Example 3 - Discretize Reacher action space with MultiDiscrete:
        >>> env = gym.make("Reacher-v5")
        >>> env.action_space
        Box(-1.0, 1.0, (2,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([-0.3, -0.5])
        >>> obs
        array([ 0.99908342,  0.99948506,  0.04280567, -0.03208766,  0.10445588,
                0.11442572, -1.18958125, -1.97979484,  0.1054461 , -0.10896341])
        >>> env = DiscretizeAction(env, bins=10, multidiscrete=True)
        >>> env.action_space
        MultiDiscrete([10 10])
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([3, 2])
        >>> obs
        array([ 0.99908342,  0.99948506,  0.04280567, -0.03208766,  0.10445588,
                0.11442572, -1.18958118, -1.97979484,  0.1054461 , -0.10896341])
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        bins: int | tuple[int, ...],
        multidiscrete: bool = False,
    ):
        """Constructor for the discretize action wrapper.

        Args:
            env: The environment to wrap.
            bins: int or tuple of ints (number of bins per dimension).
            multidiscrete: If True, use MultiDiscrete action space instead of flattening to Discrete.

        Raises:
            TypeError: If the action space of ``env`` is not a ``Box``.
            ValueError: If any bound of the action space is infinite.
            AssertionError: If ``bins`` is a sequence whose length differs from the number
                of action dimensions.
        """
        if not isinstance(env.action_space, Box):
            raise TypeError(
                "DiscretizeAction is only compatible with Box continuous actions."
            )

        self.low = env.action_space.low
        self.high = env.action_space.high
        self.n_dims = self.low.shape[0]

        if np.any(np.isinf(self.low)) or np.any(np.isinf(self.high)):
            raise ValueError(
                "Discretization requires action space to be finite. "
                f"Found: low={self.low}, high={self.high}"
            )

        self.multidiscrete = multidiscrete

        # Record every constructor argument (not only ``bins``) so that the recorded
        # arguments describe this wrapper completely.
        gym.utils.RecordConstructorArgs.__init__(
            self, bins=bins, multidiscrete=multidiscrete
        )
        gym.ActionWrapper.__init__(self, env)

        # Accept numpy integer scalars as well as plain Python ints for a shared bin count.
        if isinstance(bins, (int, np.integer)):
            self.bins = np.array([bins] * self.n_dims)
        else:
            assert (
                len(bins) == self.n_dims
            ), f"bins must match action dimensions: expected {self.n_dims}, got {len(bins)}"
            self.bins = np.array(bins)

        # For each dimension the bin edges are computed once; a bin center is the mean of
        # its two surrounding edges.
        self.bin_centers = []
        for i in range(self.n_dims):
            edges = np.linspace(self.low[i], self.high[i], self.bins[i] + 1)
            self.bin_centers.append(0.5 * (edges[:-1] + edges[1:]))

        if self.multidiscrete:
            self.action_space = MultiDiscrete(self.bins)
        else:
            self.action_space = Discrete(np.prod(self.bins))

    def action(self, act):
        """Converts a discrete action into the continuous action passed to the wrapped env.

        Args:
            act: A flat index (``Discrete`` mode) or one bin index per dimension
                (``MultiDiscrete`` mode).

        Returns:
            An array with the center of the selected bin for every dimension, using the
            dtype of the wrapped action space.
        """
        if self.multidiscrete:
            indices = np.asarray(act, dtype=int)
        else:
            indices = self._unflatten_index(act)

        # Per-dimension indices are clamped to the valid range [0, bins[i] - 1].
        centers = [
            self.bin_centers[i][min(max(idx, 0), self.bins[i] - 1)]
            for i, idx in enumerate(indices)
        ]
        return np.array(centers, dtype=self.env.action_space.dtype)

    def revert_action(self, action):
        """Converts a continuous action into the discrete action whose bin center is closest.

        Args:
            action: A continuous action with one value per dimension.

        Returns:
            An integer array of bin indices in ``MultiDiscrete`` mode, otherwise the flat
            index computed with ``np.ravel_multi_index``.
        """
        indices = [
            np.argmin(np.abs(self.bin_centers[i] - action[i]))
            for i in range(self.n_dims)
        ]
        if self.multidiscrete:
            return np.array(indices, dtype=int)
        return np.ravel_multi_index(indices, self.bins)

    def _unflatten_index(self, flat_index):
        """Splits a flat index into one bin index per dimension (last dimension varies fastest)."""
        # Work on a plain Python int: ``//=`` applied to a 0-d numpy array operates in place
        # and would overwrite the action object owned by the caller.
        remainder = int(flat_index)
        indices = []
        for n_bins in reversed(self.bins):
            remainder, idx = divmod(remainder, int(n_bins))
            indices.append(idx)
        indices.reverse()
        return indices

View File

@@ -9,6 +9,7 @@
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value * ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservation`` - Convert an observation to a dtype * ``DtypeObservation`` - Convert an observation to a dtype
* ``RenderObservation`` - Allows the observation to the rendered frame * ``RenderObservation`` - Allows the observation to the rendered frame
* ``DiscretizeObservation`` - Discretize a continuous Box observation space into a single Discrete space
""" """
from __future__ import annotations from __future__ import annotations
@@ -34,6 +35,7 @@ __all__ = [
"RescaleObservation", "RescaleObservation",
"DtypeObservation", "DtypeObservation",
"AddRenderObservation", "AddRenderObservation",
"DiscretizeObservation",
] ]
from gymnasium.wrappers.utils import rescale_box from gymnasium.wrappers.utils import rescale_box
@@ -682,3 +684,155 @@ class AddRenderObservation(
func=lambda obs: {obs_key: obs, render_key: self.render()}, func=lambda obs: {obs_key: obs, render_key: self.render()},
observation_space=obs_space, observation_space=obs_space,
) )
class DiscretizeObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Uniformly discretizes a continuous Box observation space into a single Discrete space.

    Every dimension of the wrapped ``Box`` observation space is split into ``bins`` intervals
    of equal width and an observation is replaced by the indices of the intervals it falls in
    (flattened into a single integer unless ``multidiscrete=True``).

    The number of dimensions is read from ``low.shape[0]``, i.e. the wrapper is written for
    one-dimensional ``Box`` spaces.

    Example 1 - Discretize MountainCar observation space:
        >>> env = gym.make("MountainCar-v0")
        >>> env.observation_space
        Box([-1.2  -0.07], [0.6  0.07], (2,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([-0.4452088,  0.       ], dtype=float32)
        >>> env = DiscretizeObservation(env, bins=10)
        >>> env.observation_space
        Discrete(100)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        45

    Example 2 - Discretize LunarLander observation space:
        >>> env = gym.make("LunarLander-v3")
        >>> env.observation_space
        Box([ -2.5        -2.5       -10.        -10.         -6.2831855 -10.
          -0.         -0.       ], [ 2.5        2.5       10.        10.         6.2831855 10.
          1.         1.       ], (8,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([ 0.00229702,  1.4181306 ,  0.2326471 ,  0.3204666 , -0.00265488,
               -0.05269805,  0.        ,  0.        ], dtype=float32)
        >>> env = DiscretizeObservation(env, bins=3)
        >>> env.observation_space
        Discrete(6561)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        4005

    Example 3 - Discretize LunarLander observation space with MultiDiscrete:
        >>> env = gym.make("LunarLander-v3")
        >>> env.observation_space
        Box([ -2.5        -2.5       -10.        -10.         -6.2831855 -10.
          -0.         -0.       ], [ 2.5        2.5       10.        10.         6.2831855 10.
          1.         1.       ], (8,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([ 0.00229702,  1.4181306 ,  0.2326471 ,  0.3204666 , -0.00265488,
               -0.05269805,  0.        ,  0.        ], dtype=float32)
        >>> env = DiscretizeObservation(env, bins=3, multidiscrete=True)
        >>> env.observation_space
        MultiDiscrete([3 3 3 3 3 3 3 3])
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([1, 2, 1, 1, 1, 1, 0, 0])
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        bins: int | tuple[int, ...],
        multidiscrete: bool = False,
    ):
        """Constructor for the discretize observation wrapper.

        Args:
            env: The environment to wrap.
            bins: int or tuple of ints (number of bins per dimension).
            multidiscrete: If True, use MultiDiscrete space instead of flattening to Discrete.

        Raises:
            TypeError: If the observation space of ``env`` is not a ``Box``.
            ValueError: If any bound of the observation space is infinite.
            AssertionError: If ``bins`` is a sequence whose length differs from the number
                of observation dimensions.
        """
        if not isinstance(env.observation_space, spaces.Box):
            raise TypeError(
                "DiscretizeObservation is only compatible with Box continuous observations."
            )

        self.low = env.observation_space.low
        self.high = env.observation_space.high
        self.n_dims = self.low.shape[0]

        if np.any(np.isinf(self.low)) or np.any(np.isinf(self.high)):
            raise ValueError(
                "Discretization requires observation space to be finite. "
                f"Found: low={self.low}, high={self.high}"
            )

        self.multidiscrete = multidiscrete

        # Record every constructor argument (not only ``bins``) so that the recorded
        # arguments describe this wrapper completely.
        gym.utils.RecordConstructorArgs.__init__(
            self, bins=bins, multidiscrete=multidiscrete
        )
        gym.ObservationWrapper.__init__(self, env)

        # Accept numpy integer scalars as well as plain Python ints for a shared bin count.
        if isinstance(bins, (int, np.integer)):
            self.bins = np.array([bins] * self.n_dims)
        else:
            assert (
                len(bins) == self.n_dims
            ), f"bins must match observation dimensions: expected {self.n_dims}, got {len(bins)}"
            self.bins = np.array(bins)

        # All ``bins[i] + 1`` edges of every dimension, computed once and reused by
        # ``revert_observation``.
        self._bin_bounds = [
            np.linspace(self.low[i], self.high[i], self.bins[i] + 1)
            for i in range(self.n_dims)
        ]
        # Interior edges only (first and last edge dropped), as consumed by ``np.digitize``.
        self.bin_edges = [bounds[1:-1] for bounds in self._bin_bounds]

        if self.multidiscrete:
            self.observation_space = spaces.MultiDiscrete(self.bins)
        else:
            self.observation_space = spaces.Discrete(np.prod(self.bins))

    def observation(self, observation):
        """Discretizes the observation.

        Args:
            observation: A continuous observation with one value per dimension.

        Returns:
            An ``int64`` array of bin indices in ``MultiDiscrete`` mode, otherwise the
            flat index as a Python ``int``.
        """
        # ``bin_edges[i]`` holds the ``bins[i] - 1`` interior edges, so ``np.digitize``
        # yields a value between 0 and ``bins[i] - 1``, which is a valid bin index.
        # The clip to the Box bounds (with a small margin below ``high``) is kept as a
        # defensive measure.
        # NOTE(review): for float32 bounds the 1e-8 margin may be below the floating point
        # resolution and therefore have no effect - confirm whether it is still needed.
        clipped = np.clip(observation, self.low, self.high - 1e-8)
        indices = [
            int(np.digitize(clipped[i], self.bin_edges[i])) for i in range(self.n_dims)
        ]
        if self.multidiscrete:
            return np.array(indices, dtype=np.int64)
        return int(self._flatten_indices(indices))

    def revert_observation(self, obs):
        """Reverts discretization by returning the edges of the bin the observation belongs to.

        Args:
            obs: A flat index (``Discrete`` mode) or one bin index per dimension
                (``MultiDiscrete`` mode).

        Returns:
            A tuple ``(lows, highs)`` of arrays with the lower and upper edge of the selected
            bin for every dimension, using the dtype of the wrapped observation space.
        """
        if self.multidiscrete:
            indices = np.asarray(obs, dtype=int)
        else:
            indices = self._unflatten_index(obs)

        lows = []
        highs = []
        for i, idx in enumerate(indices):
            edges = self._bin_bounds[i]
            lows.append(edges[idx])
            highs.append(edges[idx + 1])

        dtype = self.env.observation_space.dtype
        return np.array(lows, dtype=dtype), np.array(highs, dtype=dtype)

    def _flatten_indices(self, indices):
        """Combines per-dimension bin indices into one flat index (last dimension varies fastest)."""
        flat_index = 0
        for n_bins, idx in zip(self.bins, indices):
            flat_index = flat_index * int(n_bins) + int(idx)
        return flat_index

    def _unflatten_index(self, flat_index):
        """Splits a flat index into one bin index per dimension (inverse of ``_flatten_indices``)."""
        # Work on a plain Python int: ``//=`` applied to a 0-d numpy array operates in place
        # and would overwrite the observation object owned by the caller.
        remainder = int(flat_index)
        indices = []
        for n_bins in reversed(self.bins):
            remainder, idx = divmod(remainder, int(n_bins))
            indices.append(idx)
        indices.reverse()
        return indices

View File

@@ -0,0 +1,56 @@
"""Test suite for DiscretizeAction wrapper."""
import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import DiscretizeAction
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
def test_discretize_action_space_uniformity(dimensions):
    """Checks that evenly spaced continuous actions map onto every discrete action exactly once."""
    n_bins = 7
    wrapped = DiscretizeAction(
        GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,))), n_bins
    )

    # Cartesian grid with ``n_bins`` evenly spaced values along every action dimension.
    grids = np.meshgrid(*(np.linspace(0, 99, n_bins) for _ in range(dimensions)))
    # One row per grid point, one column per action dimension.
    grid_points = np.concatenate([grid.reshape(1, -1) for grid in grids], axis=0).T

    reverted = np.sort([wrapped.revert_action(point) for point in grid_points])

    n_actions = wrapped.action_space.n
    assert grid_points.shape[0] == n_actions
    assert np.all(reverted == np.arange(n_actions))
@pytest.mark.parametrize(
    "dimensions, bins, multidiscrete",
    [
        (1, 3, False),
        (2, (3, 4), False),
        (3, (3, 4, 5), False),
        (1, 3, True),
        (2, (3, 4), True),
        (3, (3, 4, 5), True),
    ],
)
def test_revert_discretize_action_space(dimensions, bins, multidiscrete):
    """Checks that ``revert_action`` recovers the discrete action that ``action`` was given."""
    base_env = GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,)))
    wrapped = DiscretizeAction(base_env, bins, multidiscrete)

    for _ in range(1000):
        sampled = wrapped.action_space.sample()
        continuous = wrapped.action(sampled)

        # The continuous action must be valid for the wrapped environment ...
        assert base_env.action_space.contains(continuous)
        # ... and converting it back must give the sampled discrete action again.
        round_trip = wrapped.revert_action(continuous)
        assert np.all(round_trip == sampled)
@pytest.mark.parametrize("low, high", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
def test_discretize_action_bounds(low, high):
    """Tests that the wrapper rejects Box action spaces with an infinite bound."""
    # The environment is built outside of ``pytest.raises`` so that only an error raised by
    # the wrapper itself can satisfy the test.
    env = GenericTestEnv(action_space=Box(low, high, shape=(1,)))

    # ``bins`` has no default value, so it must be passed for the bounds check to be reached.
    with pytest.raises(ValueError, match="requires action space to be finite"):
        DiscretizeAction(env, bins=10)
def test_discretize_action_dtype():
    """Tests that the wrapper rejects action spaces that are not a Box."""
    env = GenericTestEnv(action_space=Discrete(10))

    # ``bins`` is passed explicitly: omitting it would raise a ``TypeError`` for the missing
    # argument, which would hide whether the space type check works.
    with pytest.raises(TypeError, match="only compatible with Box"):
        DiscretizeAction(env, bins=10)

View File

@@ -0,0 +1,60 @@
"""Test suite for DiscretizeObservation wrapper."""
import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import DiscretizeObservation
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
def test_discretize_observation_space_uniformity(dimensions):
    """Checks that evenly spaced observations map onto every discrete observation exactly once."""
    n_bins = 7
    wrapped = DiscretizeObservation(
        GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,))), n_bins
    )

    # Cartesian grid with ``n_bins`` evenly spaced values along every observation dimension.
    grids = np.meshgrid(*(np.linspace(0, 99, n_bins) for _ in range(dimensions)))
    # One row per grid point, one column per observation dimension.
    grid_points = np.concatenate([grid.reshape(1, -1) for grid in grids], axis=0).T

    discretized = np.sort([wrapped.observation(point) for point in grid_points])

    n_observations = wrapped.observation_space.n
    assert grid_points.shape[0] == n_observations
    assert np.all(discretized == np.arange(n_observations))
@pytest.mark.parametrize(
    "dimensions, bins, multidiscrete",
    [
        (1, 3, False),
        (2, (3, 4), False),
        (3, (3, 4, 5), False),
        (1, 3, True),
        (2, (3, 4), True),
        (3, (3, 4, 5), True),
    ],
)
def test_revert_discretize_observation_space(dimensions, bins, multidiscrete):
    """Checks that a raw observation lies inside the bin returned by ``revert_observation``."""
    base_env = GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,)))
    wrapped = DiscretizeObservation(base_env, bins, multidiscrete)

    for seed in range(1000):
        # Reset the raw and the wrapped environment with the same seed.
        raw_obs, _ = base_env.reset(seed=seed)
        discrete_obs, _ = wrapped.reset(seed=seed)

        lower, upper = wrapped.revert_observation(
            discrete_obs,
        )
        assert np.all(raw_obs >= lower) and np.all(raw_obs <= upper)
@pytest.mark.parametrize("low, high", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
def test_discretize_observation_bounds(low, high):
    """Tests that the wrapper rejects Box observation spaces with an infinite bound."""
    # The environment is built outside of ``pytest.raises`` so that only an error raised by
    # the wrapper itself can satisfy the test.
    env = GenericTestEnv(observation_space=Box(low, high, shape=(1,)))

    # ``bins`` has no default value, so it must be passed for the bounds check to be reached.
    with pytest.raises(ValueError, match="requires observation space to be finite"):
        DiscretizeObservation(env, bins=10)
def test_discretize_observation_dtype():
    """Tests that the wrapper rejects observation spaces that are not a Box."""
    env = GenericTestEnv(observation_space=Discrete(10))

    # ``bins`` is passed explicitly: omitting it would raise a ``TypeError`` for the missing
    # argument, which would hide whether the space type check works.
    with pytest.raises(TypeError, match="only compatible with Box"):
        DiscretizeObservation(env, bins=10)