# Source: Gymnasium/tests/spaces/test_dict.py
# Snapshot: 2025-02-21 13:39:23 +00:00 (316 lines, 8.9 KiB, Python)

import re
import warnings
from collections import OrderedDict
import numpy as np
import pytest
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.utils.env_checker import data_equivalence
def test_dict_init():
    """Test the accepted `Dict` input forms, its constructor errors and its key ordering."""
    with pytest.raises(
        TypeError,
        match=r"^Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: ",
    ):
        Dict(Discrete(2))
    with pytest.raises(
        ValueError,
        match="Dict space keyword 'a' already exists in the spaces dictionary",
    ):
        Dict({"a": Discrete(3)}, a=Box(0, 1))
    with pytest.raises(
        AssertionError,
        match="Dict space element is not an instance of Space: key='b', space=Box",
    ):
        Dict(a=Discrete(2), b="Box")

    # `record=True` only captures warnings that pass the active filters, so every
    # block forces all warnings through; otherwise the `len(...) == 0` checks could
    # pass vacuously under a restrictive warnings configuration.
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")
        a = Dict({"a": Discrete(2), "b": Box(low=0.0, high=1.0)})
        b = Dict(OrderedDict(a=Discrete(2), b=Box(low=0.0, high=1.0)))
        c = Dict((("a", Discrete(2)), ("b", Box(low=0.0, high=1.0))))
        d = Dict(a=Discrete(2), b=Box(low=0.0, high=1.0))
        assert a == b == c == d
    assert len(caught_warnings) == 0

    # test sorting
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")
        # Sorting is applied to the keys of a plain dict.
        a = Dict({"b": Box(low=0.0, high=1.0), "a": Discrete(2)})
        # Comparing `keys()` with a set ignores order, so compare lists to
        # actually verify the sorting.
        assert list(a.keys()) == ["a", "b"]

        b = Dict(OrderedDict(b=Box(low=0.0, high=1.0), a=Discrete(2)))
        c = Dict((("b", Box(low=0.0, high=1.0)), ("a", Discrete(2))))
        d = Dict(b=Box(low=0.0, high=1.0), a=Discrete(2))
        assert b.keys() == c.keys() == d.keys() == {"b", "a"}
        # Sorting is not applied to Sequence or keyword inputs: insertion order is kept.
        assert list(c.keys()) == list(d.keys()) == ["b", "a"]
        # NOTE(review): `b` (OrderedDict input) is only checked as a set; confirm
        # whether `Dict` is meant to preserve OrderedDict order before tightening it.
    assert len(caught_warnings) == 0

    # test sorting with different classes: int and str keys cannot be compared,
    # so the keys are expected to stay in insertion order.
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")
        mixed = Dict({1: Discrete(2), "a": Discrete(3)})
        assert mixed.keys() == {1, "a"}
        assert list(mixed.keys()) == [1, "a"]
    assert len(caught_warnings) == 0
# Module-level nested Dict space shared by several tests in this file.
# Keys: "a" (3x3 Box), "b" (Dict of two 2-element Boxes), "c" (Discrete(5)).
DICT_SPACE = Dict(
    a=Box(low=0, high=1, shape=(3, 3)),
    b=Dict(
        b_1=Box(low=-100, high=100, shape=(2,)),
        b_2=Box(low=-1, high=1, shape=(2,)),
    ),
    c=Discrete(5),
)
def test_dict_seeding():
    """Seeding with a nested dict of ints seeds each leaf sub-space with its own value."""
    seed_tree = {"a": 0, "b": {"b_1": 1, "b_2": 2}, "c": 3}
    assert data_equivalence(DICT_SPACE.seed(seed_tree), seed_tree)

    # Stand-alone copies of every leaf sub-space, seeded with the matching value,
    # must produce exactly the same samples as the seeded Dict.
    ref_a = Box(low=0, high=1, shape=(3, 3), seed=0)
    ref_b_1 = Box(low=-100, high=100, shape=(2,), seed=1)
    ref_b_2 = Box(low=-1, high=1, shape=(2,), seed=2)
    ref_c = Discrete(5, seed=3)

    for _ in range(10):
        drawn = DICT_SPACE.sample()
        inner = drawn["b"]
        assert np.all(drawn["a"] == ref_a.sample())
        assert np.all(inner["b_1"] == ref_b_1.sample())
        assert np.all(inner["b_2"] == ref_b_2.sample())
        assert drawn["c"] == ref_c.sample()
def test_int_seeding():
    """Seeding with an int returns the per-sub-space seeds, and those seeds reproduce the Dict's samples."""
    seeds = DICT_SPACE.seed(1)
    assert isinstance(seeds, dict)

    # "Unpack" the dict sub-spaces into individual spaces, each seeded with the
    # sub-seed that the Dict reported for it.
    a = Box(low=0, high=1, shape=(3, 3), seed=seeds["a"])
    b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds["b"]["b_1"])
    b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds["b"]["b_2"])
    c = Discrete(5, seed=seeds["c"])

    for _ in range(10):
        dict_sample = DICT_SPACE.sample()
        assert np.all(dict_sample["a"] == a.sample())
        assert np.all(dict_sample["b"]["b_1"] == b_1.sample())
        assert np.all(dict_sample["b"]["b_2"] == b_2.sample())
        assert dict_sample["c"] == c.sample()
def test_none_seeding():
    """Seeding the Dict space with `None` returns a dict."""
    returned = DICT_SPACE.seed(None)
    assert isinstance(returned, dict)
def test_bad_seed():
    """Seeding with an unsupported type (a `str`) raises a `TypeError`."""
    expected_message = re.escape(
        "Expected seed type: dict, int or None, actual type: <class 'str'>"
    )
    with pytest.raises(TypeError, match=expected_message):
        DICT_SPACE.seed("a")
def test_mapping():
    """The Gym Dict space inherits from Mapping that allows it to appear like a standard python Dictionary."""
    assert len(DICT_SPACE) == 3

    original_a = DICT_SPACE["a"]
    replacement = Discrete(5)
    assert original_a != replacement

    try:
        DICT_SPACE["a"] = replacement
        assert DICT_SPACE["a"] == replacement

        # Assigning a non-space value must be rejected.
        with pytest.raises(
            AssertionError,
            match="Trying to set a to Dict space with value that is not a gymnasium space, actual type: <class 'int'>",
        ):
            DICT_SPACE["a"] = 5
    finally:
        # DICT_SPACE is shared module-level state used by other tests, so always
        # restore the original sub-space, even if an assertion above failed.
        DICT_SPACE["a"] = original_a
def test_iterator():
    """Iterating the Dict space yields exactly the keys of its `spaces` mapping."""
    iterated = set(DICT_SPACE)
    assert all(name in DICT_SPACE.spaces for name in iterated)
    assert iterated == DICT_SPACE.spaces.keys()
def test_keys_contains():
    """`Dict.keys()` answers membership correctly for present and absent keys."""
    keys_view = Dict(a=Box(0, 1), b=Box(1, 2)).keys()
    for present in keys_view:
        assert present in keys_view
    assert "a" in keys_view
    assert "c" not in keys_view
def test_sample_with_mask():
    """Test the sample method with valid masks."""
    space = Dict(a=Discrete(5), b=Box(low=0, high=1, shape=(2,)))

    # Only the value 1 is allowed for "a"; the Box sub-space gets no mask.
    allow_only_one = np.zeros(5, dtype=np.int8)
    allow_only_one[1] = 1
    mask = {"a": allow_only_one, "b": None}

    for _ in range(10):
        drawn = space.sample(mask=mask)
        assert drawn["a"] == 1
        assert space["b"].contains(drawn["b"])
def test_sample_with_probability():
    """Test the sample method with valid probabilities."""
    space = Dict(
        {
            "a": Discrete(3),
            "b": Box(low=0, high=1, shape=(2,)),
        }
    )
    # Seed the space so this statistical check is reproducible run-to-run.
    space.seed(0)
    probability = {
        "a": np.array(
            [0.1, 0.7, 0.2], dtype=np.float64
        ),  # Sampling probabilities for Discrete space
        "b": None,  # No probability for Box space
    }
    samples = [space.sample(probability=probability)["a"] for _ in range(1000)]
    # Check that the sampling roughly follows the probability distribution;
    # `decimal=1` tolerates an absolute error below 0.15 per outcome.
    counts = np.bincount(samples, minlength=3) / len(samples)
    np.testing.assert_almost_equal(counts, probability["a"], decimal=1)
def test_sample_with_invalid_mask():
    """Masks with the wrong shape or dtype are rejected with an AssertionError."""
    space = Dict(a=Discrete(5), b=Box(low=0, high=1, shape=(2,)))

    invalid_cases = [
        # Length mismatch: 3 entries for a Discrete(5) space.
        (
            np.array([1, 0, 0], dtype=np.int8),
            "The expected shape of the sample mask is (5,), actual shape: (3,)",
        ),
        # dtype mismatch: correct length but float32 instead of int8.
        (
            np.array([1, 0, 0, 1, 1], dtype=np.float32),
            "The expected dtype of the sample mask is np.int8, actual dtype: float32",
        ),
    ]
    for bad_mask, message in invalid_cases:
        with pytest.raises(AssertionError, match=re.escape(message)):
            space.sample(mask={"a": bad_mask, "b": None})
def test_sample_with_invalid_probability():
    """Test the sample method with an invalid probability."""
    space = Dict(
        {
            "a": Discrete(5),
            "b": Box(low=0, high=1, shape=(2,)),
        }
    )
    # Wrong length with the correct dtype: only the shape check can fail.
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "The expected shape of the sample probability is (5,), actual shape: (2,)"
        ),
    ):
        space.sample(
            probability={
                "a": np.array([0.5, 0.5], dtype=np.float64),  # Length mismatch
                "b": None,
            }
        )
    # Correct length with the wrong dtype: only the dtype check can fail, so the
    # expected message does not depend on the order of the validation checks.
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "The expected dtype of the sample probability is np.float64, actual dtype: int8"
        ),
    ):
        space.sample(
            probability={
                "a": np.array([1, 0, 0, 0, 0], dtype=np.int8),  # dtype mismatch
                "b": None,
            }
        )
def test_sample_with_mask_and_probability():
    """Passing both `mask` and `probability` to `Dict.sample` raises a ValueError."""
    space = Dict(a=Discrete(3), b=Box(low=0, high=1, shape=(2,)))
    conflicting_kwargs = {
        "mask": {"a": np.array([1, 0, 1], dtype=np.int8), "b": None},
        "probability": {
            "a": np.array([0.5, 0.2, 0.3], dtype=np.float64),
            "b": None,
        },
    }
    expected = re.escape("Only one of `mask` or `probability` can be provided")
    with pytest.raises(ValueError, match=expected):
        space.sample(**conflicting_kwargs)