Files
Gymnasium/tests/spaces/test_text.py
2025-02-21 13:39:23 +00:00

80 lines
2.1 KiB
Python

import re
import numpy as np
import pytest
from gymnasium.spaces import Text
def test_sample_mask():
    """Exercise ``Text.sample`` with the ``mask`` argument.

    The mask is passed as a ``(length, character_mask)`` pair in which either
    entry may be ``None``. The checks cover: a fixed length, no mask at all,
    an all-zero character mask (on spaces with minimum length 1 and 0), a
    requested length of zero, and a character mask that leaves exactly one
    character selectable.
    """
    bounded_space = Text(min_length=1, max_length=5)

    # Length-only mask: the sample must have exactly the requested length.
    text = bounded_space.sample(mask=(3, None))
    assert text in bounded_space
    assert len(text) == 3

    # No mask: the length must fall within the space's own bounds.
    text = bounded_space.sample(mask=None)
    assert text in bounded_space
    assert 1 <= len(text) <= 5

    # With min_length=1, an all-zero character mask is expected to be rejected.
    all_zero_mask_error = re.escape(
        "Trying to sample with a minimum length > 0 (actual minimum length=1) but the character mask is all zero meaning that no character could be sampled."
    )
    with pytest.raises(ValueError, match=all_zero_mask_error):
        bounded_space.sample(
            mask=(3, np.zeros(len(bounded_space.character_set), dtype=np.int8))
        )

    # With min_length=0, the same all-zero mask is expected to give "".
    empty_allowed_space = Text(min_length=0, max_length=5)
    no_character_mask = np.zeros(
        len(empty_allowed_space.character_set), dtype=np.int8
    )
    text = empty_allowed_space.sample(mask=(None, no_character_mask))
    assert text in empty_allowed_space
    assert text == ""

    # Explicitly requesting length 0 is expected to give "" as well.
    text = empty_allowed_space.sample(mask=(0, None))
    assert text in empty_allowed_space
    assert text == ""

    # Character mask: only index 1 ("b" in charset "abcd") is enabled.
    abcd_space = Text(max_length=5, charset="abcd")
    only_b_mask = np.array([0, 1, 0, 0], dtype=np.int8)
    text = abcd_space.sample(mask=(3, only_b_mask))
    assert text in abcd_space
    assert text == "bbb"
def test_sample_probability():
    """Exercise ``Text.sample`` with the ``probability`` argument.

    The probability is passed as a ``(length, character_probabilities)`` pair
    in which either entry may be ``None``. The checks cover: a fixed length,
    no probability at all, an all-zero probability vector (expected to fail an
    assertion), and probability vectors that restrict which characters may
    appear.
    """
    bounded_space = Text(min_length=1, max_length=5)

    # Length-only probability: the sample must have the requested length.
    text = bounded_space.sample(probability=(3, None))
    assert text in bounded_space
    assert len(text) == 3

    # No probability: the length must fall within the space's own bounds.
    text = bounded_space.sample(probability=None)
    assert text in bounded_space
    assert 1 <= len(text) <= 5

    # A probability vector summing to 0 is expected to trip an assertion.
    bad_sum_error = re.escape(
        "Expects the sum of the probability mask to be 1, actual sum: 0.0"
    )
    with pytest.raises(AssertionError, match=bad_sum_error):
        bounded_space.sample(
            probability=(
                3,
                np.zeros(len(bounded_space.character_set), dtype=np.float64),
            )
        )

    # All probability mass on index 1 ("b" in charset "abcd").
    abcd_space = Text(max_length=5, charset="abcd")
    only_b = np.array([0, 1, 0, 0], dtype=np.float64)
    text = abcd_space.sample(probability=(3, only_b))
    assert text in abcd_space
    assert text == "bbb"

    # Mass split evenly between "a" and "b": any two-letter mix of them is OK.
    a_or_b = np.array([0.5, 0.5, 0, 0], dtype=np.float64)
    text = abcd_space.sample(probability=(2, a_or_b))
    assert text in abcd_space
    assert text in ["aa", "bb", "ab", "ba"]