Files
Gymnasium/tests/spaces/test_tuple.py

110 lines
3.0 KiB
Python

import numpy as np
import pytest
import gymnasium as gym
from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, Tuple
from gymnasium.utils.env_checker import data_equivalence
def test_sequence_inheritance():
    """Check that the ``Tuple`` space behaves like a ``collections.abc.Sequence``.

    Exercises ``len``, integer indexing, iteration, ``count`` and ``index``,
    plus the errors raised for a value missing from the searched range and
    for an out-of-range index.
    """
    spaces = [Discrete(5), Discrete(10), Discrete(5)]
    tuple_space = Tuple(spaces)

    assert len(tuple_space) == len(spaces)

    # Test indexing
    for i, expected_space in enumerate(spaces):
        assert tuple_space[i] == expected_space

    # Test iterable
    for space in tuple_space:
        assert space in spaces

    # Test count
    assert tuple_space.count(Discrete(5)) == 2
    assert tuple_space.count(Discrete(6)) == 0
    assert tuple_space.count(MultiBinary(2)) == 0

    # Test index
    assert tuple_space.index(Discrete(5)) == 0
    assert tuple_space.index(Discrete(5), 1) == 2

    # Test errors
    with pytest.raises(ValueError):
        # Discrete(10) sits at position 1, outside the searched range [0, 1).
        tuple_space.index(Discrete(10), 0, 1)
    with pytest.raises(IndexError):
        # Only indices 0-2 exist; the lookup itself must raise, so there is
        # nothing to assert on its result.
        _ = tuple_space[4]
@pytest.mark.parametrize(
    "space, seed, expected_len",
    [
        (Tuple([Discrete(5), Discrete(4)]), None, 2),
        (Tuple([Discrete(5), Discrete(4)]), 123, 3),
        (Tuple([Discrete(5), Discrete(4)]), (123, 456), 2),
        (
            Tuple(
                (Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
            ),
            (123, (456, 789)),
            3,
        ),
        (
            Tuple(
                (
                    Discrete(3),
                    Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)),
                )
            ),
            (123, {"position": 456, "velocity": 789}),
            3,
        ),
    ],
)
def test_seeds(space, seed, expected_len):
    """Check the seeds returned by ``space.seed`` and that reseeding reproduces samples.

    Args:
        space: The ``Tuple`` space under test.
        seed: Value passed to ``space.seed``: ``None``, an int, or a tuple of
            per-subspace seeds (which may themselves be nested tuples/dicts).
        expected_len: Number of seeds ``space.seed`` is expected to return.
    """
    seeds = space.seed(seed)
    assert isinstance(seeds, list)
    assert all(isinstance(elem, int) for elem in seeds)
    assert len(seeds) == expected_len

    sample1 = space.sample()

    # Reseed with the same value and draw again.
    seeds2 = space.seed(seed)
    sample2 = space.sample()

    # ``data_equivalence`` reports its verdict through its return value, so the
    # result has to be asserted; calling it bare can never fail the test.
    # An unseeded space (``seed=None``) is not expected to reproduce, so the
    # comparison is only meaningful when an explicit seed was supplied.
    if seed is not None:
        assert data_equivalence(seeds, seeds2)
        assert data_equivalence(sample1, sample2)
@pytest.mark.parametrize(
    "space_fn",
    [
        lambda: Tuple(["abc"]),
        lambda: Tuple([Box(0, 1), "abc"]),
        lambda: Tuple("abc"),
    ],
)
def test_bad_space_calls(space_fn):
    """Building a ``Tuple`` from non-space elements must raise ``AssertionError``.

    Each case is a zero-argument callable so the failing constructor runs
    inside the ``pytest.raises`` block rather than while the parameters are
    being collected.
    """
    with pytest.raises(AssertionError):
        space_fn()
def test_contains_promotion():
    """Check ``in`` on a ``Tuple`` space for a tuple of arrays and a stacked array."""
    # Case 1: an explicit tuple holding one float32 array per sub-space.
    pair_space = Tuple((Box(0, 1), Box(-1, 0, (2,))))
    pair_sample = (
        np.array([0.0], dtype=np.float32),
        np.array([0.0, 0.0], dtype=np.float32),
    )
    assert pair_sample in pair_space

    # Case 2: a single 2-D array with one row per sub-space must also be
    # reported as contained.
    stacked_space = Tuple((Box(0, 1), Box(-1, 0, (1,))))
    stacked_sample = np.array([[0.0], [0.0]], dtype=np.float32)
    assert stacked_sample in stacked_space
def test_bad_seed():
    """Seeding a ``Tuple`` space with a float must raise a descriptive ``TypeError``."""
    space = Tuple((Box(0, 1), Box(0, 1)))
    expected_message = (
        "Expected seed type: list, tuple, int or None, actual type: <class 'float'>"
    )

    with pytest.raises(TypeError, match=expected_message):
        space.seed(0.0)