Wrapper to discretize observations and actions (#1411)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
Simone Parisi
2025-07-03 13:05:55 -06:00
committed by GitHub
parent 7f11576fe6
commit 18da906de7
5 changed files with 420 additions and 1 deletions

View File

@@ -76,11 +76,13 @@ from gymnasium.wrappers.stateful_observation import (
from gymnasium.wrappers.stateful_reward import NormalizeReward
from gymnasium.wrappers.transform_action import (
ClipAction,
DiscretizeAction,
RescaleAction,
TransformAction,
)
from gymnasium.wrappers.transform_observation import (
AddRenderObservation,
DiscretizeObservation,
DtypeObservation,
FilterObservation,
FlattenObservation,
@@ -99,6 +101,7 @@ __all__ = [
"AtariPreprocessing",
"DelayObservation",
"DtypeObservation",
"DiscretizeObservation",
"FilterObservation",
"FlattenObservation",
"FrameStackObservation",
@@ -113,6 +116,7 @@ __all__ = [
"TimeAwareObservation",
# --- Action Wrappers ---
"ClipAction",
"DiscretizeAction",
"TransformAction",
"RescaleAction",
# "NanAction",

View File

@@ -2,6 +2,7 @@
* ``TransformAction`` - Transforms the actions based on a function
* ``ClipAction`` - Clips the action within a bounds
* ``DiscretizeAction`` - Discretizes a continuous Box action space into a single Discrete space
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
"""
@@ -13,7 +14,7 @@ import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType, WrapperActType
from gymnasium.spaces import Box, Space
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Space
# Public API of this module; every wrapper class defined here must be listed.
__all__ = ["TransformAction", "ClipAction", "RescaleAction", "DiscretizeAction"]
@@ -178,3 +179,147 @@ class RescaleAction(
func=func,
action_space=act_space,
)
class DiscretizeAction(
    TransformAction[ObsType, WrapperActType, ActType],
    gym.utils.RecordConstructorArgs,
):
    """Uniformly discretizes a continuous Box action space into a single Discrete space.

    Every dimension of the wrapped ``Box`` action space is split into equal-width bins.
    A discrete action selects one bin per dimension and the wrapped environment receives
    the center of each selected bin. With ``multidiscrete=True`` the wrapper exposes one
    index per dimension (``MultiDiscrete``) instead of a single flattened index (``Discrete``).

    Example 1 - Discretize Pendulum action space:
        >>> env = gym.make("Pendulum-v1")
        >>> env.action_space
        Box(-2.0, 2.0, (1,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([-0.6])
        >>> obs
        array([-0.17606162, 0.9843792 , 0.5292768 ], dtype=float32)
        >>> env = DiscretizeAction(env, bins=10)
        >>> env.action_space
        Discrete(10)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step(3)
        >>> obs
        array([-0.17606162, 0.9843792 , 0.5292768 ], dtype=float32)

    Example 2 - Discretize Reacher action space:
        >>> env = gym.make("Reacher-v5")
        >>> env.action_space
        Box(-1.0, 1.0, (2,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([-0.3, -0.5])
        >>> obs
        array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
        0.11442572, -1.18958125, -1.97979484, 0.1054461 , -0.10896341])
        >>> env = DiscretizeAction(env, bins=10)
        >>> env.action_space
        Discrete(100)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step(32)
        >>> obs
        array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
        0.11442572, -1.18958118, -1.97979484, 0.1054461 , -0.10896341])

    Example 3 - Discretize Reacher action space with MultiDiscrete:
        >>> env = gym.make("Reacher-v5")
        >>> env.action_space
        Box(-1.0, 1.0, (2,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([-0.3, -0.5])
        >>> obs
        array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
        0.11442572, -1.18958125, -1.97979484, 0.1054461 , -0.10896341])
        >>> env = DiscretizeAction(env, bins=10, multidiscrete=True)
        >>> env.action_space
        MultiDiscrete([10 10])
        >>> obs, _ = env.reset(seed=42)
        >>> obs, *_ = env.step([3, 2])
        >>> obs
        array([ 0.99908342, 0.99948506, 0.04280567, -0.03208766, 0.10445588,
        0.11442572, -1.18958118, -1.97979484, 0.1054461 , -0.10896341])
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        bins: int | tuple[int, ...],
        multidiscrete: bool = False,
    ):
        """Constructor for the discretize action wrapper.

        Args:
            env: The environment to wrap.
            bins: int or tuple of ints (number of bins per dimension).
            multidiscrete: If True, use MultiDiscrete action space instead of flattening to Discrete.

        Raises:
            TypeError: If the action space of ``env`` is not a ``Box``.
            ValueError: If any bound of the action space is infinite.
        """
        if not isinstance(env.action_space, Box):
            raise TypeError(
                "DiscretizeAction is only compatible with Box continuous actions."
            )

        self.low = env.action_space.low
        self.high = env.action_space.high
        # Only the first axis is consulted for the number of dimensions.
        self.n_dims = self.low.shape[0]

        if np.any(np.isinf(self.low)) or np.any(np.isinf(self.high)):
            raise ValueError(
                "Discretization requires action space to be finite. "
                f"Found: low={self.low}, high={self.high}"
            )

        self.multidiscrete = multidiscrete

        # Record every constructor argument; leaving ``multidiscrete`` out
        # would drop it from the recorded constructor arguments.
        gym.utils.RecordConstructorArgs.__init__(
            self, bins=bins, multidiscrete=multidiscrete
        )
        gym.ActionWrapper.__init__(self, env)

        # Accept NumPy integer scalars as well as Python ints for a shared bin count.
        if isinstance(bins, (int, np.integer)):
            self.bins = np.array([bins] * self.n_dims)
        else:
            assert (
                len(bins) == self.n_dims
            ), f"bins must match action dimensions: expected {self.n_dims}, got {len(bins)}"
            self.bins = np.array(bins)

        # The center of each bin is the midpoint of its two neighbouring edges.
        self.bin_centers = []
        for i in range(self.n_dims):
            edges = np.linspace(self.low[i], self.high[i], self.bins[i] + 1)
            self.bin_centers.append(0.5 * (edges[:-1] + edges[1:]))

        if self.multidiscrete:
            self.action_space = MultiDiscrete(self.bins)
        else:
            self.action_space = Discrete(np.prod(self.bins))

    def action(self, act):
        """Maps a discrete action to the continuous action made of the selected bin centers.

        Args:
            act: A flat index (``Discrete`` mode) or one index per dimension (``MultiDiscrete`` mode).

        Returns:
            An array of bin centers using the dtype of the wrapped action space.
        """
        if self.multidiscrete:
            indices = np.asarray(act, dtype=int)
        else:
            indices = self._unflatten_index(act)

        # Out-of-range indices are clamped to the first/last bin of their dimension.
        centers = [
            self.bin_centers[i][min(max(idx, 0), self.bins[i] - 1)]
            for i, idx in enumerate(indices)
        ]
        return np.array(centers, dtype=self.env.action_space.dtype)

    def revert_action(self, action):
        """Converts a continuous action to the discrete action whose bin center is closest.

        Args:
            action: A continuous action, indexable per dimension.

        Returns:
            An integer array of per-dimension bin indices in ``MultiDiscrete`` mode,
            otherwise the flattened index computed with ``np.ravel_multi_index``.
        """
        indices = [
            np.argmin(np.abs(self.bin_centers[i] - action[i]))
            for i in range(self.n_dims)
        ]
        if self.multidiscrete:
            return np.array(indices, dtype=int)
        else:
            return np.ravel_multi_index(indices, self.bins)

    def _unflatten_index(self, flat_index):
        """Splits a flat index into per-dimension bin indices (last dimension varies fastest)."""
        # Work on a Python int so the in-place floor division below can never
        # mutate a 0-d array owned by the caller.
        flat_index = int(flat_index)
        indices = []
        for b in reversed(self.bins):
            indices.append(flat_index % b)
            flat_index //= b
        indices.reverse()
        return indices

View File

@@ -9,6 +9,7 @@
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservation`` - Convert an observation to a dtype
* ``RenderObservation`` - Allows the observation to include the rendered frame
* ``DiscretizeObservation`` - Discretize a continuous Box observation space into a single Discrete space
"""
from __future__ import annotations
@@ -34,6 +35,7 @@ __all__ = [
"RescaleObservation",
"DtypeObservation",
"AddRenderObservation",
"DiscretizeObservation",
]
from gymnasium.wrappers.utils import rescale_box
@@ -682,3 +684,155 @@ class AddRenderObservation(
func=lambda obs: {obs_key: obs, render_key: self.render()},
observation_space=obs_space,
)
class DiscretizeObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Uniformly discretizes a continuous Box observation space into a single Discrete space.

    Every dimension of the wrapped ``Box`` observation space is split into equal-width bins
    and each observation is replaced by the index of the bin it falls in. With
    ``multidiscrete=True`` the wrapper returns one index per dimension (``MultiDiscrete``)
    instead of a single flattened index (``Discrete``).

    Example 1 - Discretize MountainCar observation space:
        >>> env = gym.make("MountainCar-v0")
        >>> env.observation_space
        Box([-1.2 -0.07], [0.6 0.07], (2,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([-0.4452088, 0. ], dtype=float32)
        >>> env = DiscretizeObservation(env, bins=10)
        >>> env.observation_space
        Discrete(100)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        45

    Example 2 - Discretize LunarLander observation space:
        >>> env = gym.make("LunarLander-v3")
        >>> env.observation_space
        Box([ -2.5 -2.5 -10. -10. -6.2831855 -10.
        -0. -0. ], [ 2.5 2.5 10. 10. 6.2831855 10.
        1. 1. ], (8,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([ 0.00229702, 1.4181306 , 0.2326471 , 0.3204666 , -0.00265488,
        -0.05269805, 0. , 0. ], dtype=float32)
        >>> env = DiscretizeObservation(env, bins=3)
        >>> env.observation_space
        Discrete(6561)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        4005

    Example 3 - Discretize LunarLander observation space with MultiDiscrete:
        >>> env = gym.make("LunarLander-v3")
        >>> env.observation_space
        Box([ -2.5 -2.5 -10. -10. -6.2831855 -10.
        -0. -0. ], [ 2.5 2.5 10. 10. 6.2831855 10.
        1. 1. ], (8,), float32)
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([ 0.00229702, 1.4181306 , 0.2326471 , 0.3204666 , -0.00265488,
        -0.05269805, 0. , 0. ], dtype=float32)
        >>> env = DiscretizeObservation(env, bins=3, multidiscrete=True)
        >>> env.observation_space
        MultiDiscrete([3 3 3 3 3 3 3 3])
        >>> obs, _ = env.reset(seed=42)
        >>> obs
        array([1, 2, 1, 1, 1, 1, 0, 0])
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        bins: int | tuple[int, ...],
        multidiscrete: bool = False,
    ):
        """Constructor for the discretize observation wrapper.

        Args:
            env: The environment to wrap.
            bins: int or tuple of ints (number of bins per dimension).
            multidiscrete: If True, use MultiDiscrete space instead of flattening to Discrete.

        Raises:
            TypeError: If the observation space of ``env`` is not a ``Box``.
            ValueError: If any bound of the observation space is infinite.
        """
        if not isinstance(env.observation_space, spaces.Box):
            raise TypeError(
                "DiscretizeObservation is only compatible with Box continuous observations."
            )

        self.low = env.observation_space.low
        self.high = env.observation_space.high
        # Only the first axis is consulted for the number of dimensions.
        self.n_dims = self.low.shape[0]

        if np.any(np.isinf(self.low)) or np.any(np.isinf(self.high)):
            raise ValueError(
                "Discretization requires observation space to be finite. "
                f"Found: low={self.low}, high={self.high}"
            )

        self.multidiscrete = multidiscrete

        # Record every constructor argument; leaving ``multidiscrete`` out
        # would drop it from the recorded constructor arguments.
        gym.utils.RecordConstructorArgs.__init__(
            self, bins=bins, multidiscrete=multidiscrete
        )
        gym.ObservationWrapper.__init__(self, env)

        # Accept NumPy integer scalars as well as Python ints for a shared bin count.
        if isinstance(bins, (int, np.integer)):
            self.bins = np.array([bins] * self.n_dims)
        else:
            assert (
                len(bins) == self.n_dims
            ), f"bins must match observation dimensions: expected {self.n_dims}, got {len(bins)}"
            self.bins = np.array(bins)

        # Keep only the interior edges: ``bins[i] - 1`` thresholds per dimension.
        self.bin_edges = [
            np.linspace(self.low[i], self.high[i], self.bins[i] + 1)[1:-1]
            for i in range(self.n_dims)
        ]

        if self.multidiscrete:
            self.observation_space = spaces.MultiDiscrete(self.bins)
        else:
            self.observation_space = spaces.Discrete(np.prod(self.bins))

    def observation(self, observation):
        """Replaces a continuous observation by the index of the bin(s) it falls in.

        Args:
            observation: An observation from the wrapped environment, indexable per dimension.

        Returns:
            An ``int64`` array of per-dimension bin indices in ``MultiDiscrete`` mode,
            otherwise a single Python ``int`` holding the flattened index.
        """
        # np.digitize returns len(edges) for inputs beyond the last edge. Because
        # only the interior edges are stored, that maximum is bins[i] - 1, i.e. the
        # last valid bin. The clip additionally pins the observation to the declared
        # bounds (minus a small margin) before binning.
        clipped = np.clip(observation, self.low, self.high - 1e-8)
        indices = [
            int(np.digitize(clipped[i], self.bin_edges[i])) for i in range(self.n_dims)
        ]
        if self.multidiscrete:
            return np.array(indices, dtype=np.int64)
        else:
            return int(self._flatten_indices(indices))

    def revert_observation(self, obs):
        """Reverts discretization by returning the edges of the bin the observation belongs to.

        Args:
            obs: A discretized observation (flat index, or one index per dimension in
                ``MultiDiscrete`` mode).

        Returns:
            A ``(lows, highs)`` tuple of arrays holding the lower and upper edge of the
            selected bin in every dimension, using the dtype of the wrapped observation space.
        """
        if self.multidiscrete:
            indices = np.asarray(obs, dtype=int)
        else:
            indices = self._unflatten_index(obs)

        lows = []
        highs = []
        for i, idx in enumerate(indices):
            # Full edge list, including both outer bounds, for this dimension.
            edges = np.linspace(self.low[i], self.high[i], self.bins[i] + 1)
            lows.append(edges[idx])
            highs.append(edges[idx + 1])

        return np.array(lows, dtype=self.env.observation_space.dtype), np.array(
            highs, dtype=self.env.observation_space.dtype
        )

    def _flatten_indices(self, indices):
        """Combines per-dimension bin indices into one flat index (last dimension varies fastest)."""
        flat_index = 0
        for i in range(self.n_dims):
            flat_index *= self.bins[i]
            flat_index += indices[i]
        return flat_index

    def _unflatten_index(self, flat_index):
        """Splits a flat index into per-dimension bin indices (inverse of ``_flatten_indices``)."""
        # Work on a Python int so the in-place floor division below can never
        # mutate a 0-d array owned by the caller.
        flat_index = int(flat_index)
        indices = []
        for b in reversed(self.bins):
            indices.append(flat_index % b)
            flat_index //= b
        indices.reverse()
        return indices

View File

@@ -0,0 +1,56 @@
"""Test suite for DiscretizeAction wrapper."""
import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import DiscretizeAction
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
def test_discretize_action_space_uniformity(dimensions):
    """Checks that a regular grid of continuous actions covers every discrete action exactly once."""
    n_bins = 7
    wrapped = DiscretizeAction(
        GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,))), n_bins
    )

    # Build the full Cartesian grid with ``n_bins`` evenly spaced values per dimension.
    axis_values = np.linspace(0, 99, n_bins)
    grids = np.meshgrid(*[axis_values] * dimensions)
    grid_points = np.stack([g.flatten() for g in grids], axis=0).T

    discrete_ids = np.sort([wrapped.revert_action(point) for point in grid_points])
    n_actions = wrapped.action_space.n

    assert grid_points.shape[0] == n_actions
    assert np.all(discrete_ids == np.arange(n_actions))
@pytest.mark.parametrize(
    "dimensions, bins, multidiscrete",
    [
        (1, 3, False),
        (2, (3, 4), False),
        (3, (3, 4, 5), False),
        (1, 3, True),
        (2, (3, 4), True),
        (3, (3, 4, 5), True),
    ],
)
def test_revert_discretize_action_space(dimensions, bins, multidiscrete):
    """Checks that ``action`` followed by ``revert_action`` returns the sampled discrete action."""
    base_env = GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,)))
    wrapped = DiscretizeAction(base_env, bins, multidiscrete)

    for _ in range(1000):
        sampled = wrapped.action_space.sample()
        continuous = wrapped.action(sampled)

        # The produced continuous action must be valid for the wrapped space ...
        assert base_env.action_space.contains(continuous)
        # ... and must map back to the discrete action it came from.
        assert np.all(wrapped.revert_action(continuous) == sampled)
@pytest.mark.parametrize("low, high", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
def test_discretize_action_bounds(low, high):
    """Tests that the discretize action wrapper rejects action spaces with infinite bounds."""
    # Build the environment outside the ``raises`` block so that only an error
    # coming from the wrapper itself can satisfy the test.
    env = GenericTestEnv(action_space=Box(low, high, shape=(1,)))
    with pytest.raises(ValueError, match="requires action space to be finite"):
        DiscretizeAction(env, bins=5)
def test_discretize_action_dtype():
    """Tests that the discretize action wrapper rejects a non-Box action space."""
    env = GenericTestEnv(action_space=Discrete(10))
    # ``bins`` is passed so the TypeError can only come from the wrapper's space
    # check, not from a missing required argument.
    with pytest.raises(TypeError, match="only compatible with Box"):
        DiscretizeAction(env, bins=5)

View File

@@ -0,0 +1,60 @@
"""Test suite for DiscretizeObservation wrapper."""
import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import DiscretizeObservation
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
def test_discretize_observation_space_uniformity(dimensions):
    """Checks that a regular grid of observations covers every discrete observation exactly once."""
    n_bins = 7
    wrapped = DiscretizeObservation(
        GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,))), n_bins
    )

    # Build the full Cartesian grid with ``n_bins`` evenly spaced values per dimension.
    axis_values = np.linspace(0, 99, n_bins)
    grids = np.meshgrid(*[axis_values] * dimensions)
    grid_points = np.stack([g.flatten() for g in grids], axis=0).T

    discrete_ids = np.sort([wrapped.observation(point) for point in grid_points])
    n_observations = wrapped.observation_space.n

    assert grid_points.shape[0] == n_observations
    assert np.all(discrete_ids == np.arange(n_observations))
@pytest.mark.parametrize(
    "dimensions, bins, multidiscrete",
    [
        (1, 3, False),
        (2, (3, 4), False),
        (3, (3, 4, 5), False),
        (1, 3, True),
        (2, (3, 4), True),
        (3, (3, 4, 5), True),
    ],
)
def test_revert_discretize_observation_space(dimensions, bins, multidiscrete):
    """Checks that each raw observation lies inside the bin edges reverted from its discrete form."""
    base_env = GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,)))
    wrapped = DiscretizeObservation(base_env, bins, multidiscrete)

    for seed in range(1000):
        # Reset both views with the same seed so that they can be compared.
        raw_obs, _ = base_env.reset(seed=seed)
        discrete_obs, _ = wrapped.reset(seed=seed)

        lower, upper = wrapped.revert_observation(discrete_obs)
        assert np.all(raw_obs >= lower) and np.all(raw_obs <= upper)
@pytest.mark.parametrize("low, high", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
def test_discretize_observation_bounds(low, high):
    """Tests that the discretize observation wrapper rejects observation spaces with infinite bounds."""
    # Build the environment outside the ``raises`` block so that only an error
    # coming from the wrapper itself can satisfy the test.
    env = GenericTestEnv(observation_space=Box(low, high, shape=(1,)))
    with pytest.raises(ValueError, match="requires observation space to be finite"):
        DiscretizeObservation(env, bins=5)
def test_discretize_observation_dtype():
    """Tests that the discretize observation wrapper rejects a non-Box observation space."""
    env = GenericTestEnv(observation_space=Discrete(10))
    # ``bins`` is passed so the TypeError can only come from the wrapper's space
    # check, not from a missing required argument.
    with pytest.raises(TypeError, match="only compatible with Box"):
        DiscretizeObservation(env, bins=5)