mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-16 11:39:13 +00:00
80 lines
2.1 KiB
Python
80 lines
2.1 KiB
Python
import re
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from gymnasium.spaces import Text
|
|
|
|
|
|
def test_sample_mask():
|
|
space = Text(min_length=1, max_length=5)
|
|
|
|
# Test the sample length
|
|
sample = space.sample(mask=(3, None))
|
|
assert sample in space
|
|
assert len(sample) == 3
|
|
|
|
sample = space.sample(mask=None)
|
|
assert sample in space
|
|
assert 1 <= len(sample) <= 5
|
|
|
|
with pytest.raises(
|
|
ValueError,
|
|
match=re.escape(
|
|
"Trying to sample with a minimum length > 0 (actual minimum length=1) but the character mask is all zero meaning that no character could be sampled."
|
|
),
|
|
):
|
|
space.sample(mask=(3, np.zeros(len(space.character_set), dtype=np.int8)))
|
|
|
|
space = Text(min_length=0, max_length=5)
|
|
sample = space.sample(
|
|
mask=(None, np.zeros(len(space.character_set), dtype=np.int8))
|
|
)
|
|
assert sample in space
|
|
assert sample == ""
|
|
|
|
sample = space.sample(mask=(0, None))
|
|
assert sample in space
|
|
assert sample == ""
|
|
|
|
# Test the sample characters
|
|
space = Text(max_length=5, charset="abcd")
|
|
|
|
sample = space.sample(mask=(3, np.array([0, 1, 0, 0], dtype=np.int8)))
|
|
assert sample in space
|
|
assert sample == "bbb"
|
|
|
|
|
|
def test_sample_probability():
|
|
space = Text(min_length=1, max_length=5)
|
|
|
|
# Test the sample length
|
|
sample = space.sample(probability=(3, None))
|
|
assert sample in space
|
|
assert len(sample) == 3
|
|
|
|
sample = space.sample(probability=None)
|
|
assert sample in space
|
|
assert 1 <= len(sample) <= 5
|
|
|
|
with pytest.raises(
|
|
AssertionError,
|
|
match=re.escape(
|
|
"Expects the sum of the probability mask to be 1, actual sum: 0.0"
|
|
),
|
|
):
|
|
space.sample(
|
|
probability=(3, np.zeros(len(space.character_set), dtype=np.float64))
|
|
)
|
|
|
|
# Test the sample characters
|
|
space = Text(max_length=5, charset="abcd")
|
|
|
|
sample = space.sample(probability=(3, np.array([0, 1, 0, 0], dtype=np.float64)))
|
|
assert sample in space
|
|
assert sample == "bbb"
|
|
|
|
sample = space.sample(probability=(2, np.array([0.5, 0.5, 0, 0], dtype=np.float64)))
|
|
assert sample in space
|
|
assert sample in ["aa", "bb", "ab", "ba"]
|