Add probability masking to space.sample (#1310)

Co-authored-by: Mario Jerez <jerezmario1@gmail.com>
This commit is contained in:
Mark Towers
2025-02-21 13:39:23 +00:00
committed by GitHub
parent 1dffcc6ed4
commit e4c1f901e9
21 changed files with 1053 additions and 182 deletions

View File

@@ -105,3 +105,68 @@ def test_bad_seed():
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",
):
space.seed(0.0)
def test_oneof_sample():
"""Tests the sample method with and without masks or probabilities."""
space = gym.spaces.Tuple([Discrete(2), Box(-1, 1, shape=(2,))])
# Unmasked sampling
sample = space.sample()
assert isinstance(sample, tuple)
assert len(sample) == 2
assert space.spaces[0].contains(sample[0])
assert space.spaces[1].contains(sample[1])
# Masked sampling
mask = (np.array([1, 0], dtype=np.int8), None)
sample = space.sample(mask=mask)
assert space.spaces[0].contains(sample[0])
assert space.spaces[1].contains(sample[1])
assert sample[0] == 0
# Probability sampling
probability = (np.array([0.8, 0.2], dtype=np.float64), None)
samples_discrete = np.array(
[space.sample(probability=probability)[0] for _ in range(1000)]
)
counts = np.bincount(samples_discrete, minlength=2) / len(samples_discrete)
np.testing.assert_allclose(counts, probability[0], atol=0.05)
def test_invalid_sample_inputs():
"""Tests that invalid inputs to sample raise appropriate errors."""
space = gym.spaces.Tuple([Discrete(2), Box(-1, 1, shape=(2,))])
# Providing both mask and probability
with pytest.raises(
ValueError, match="Only one of `mask` or `probability` can be provided."
):
space.sample(mask=(None, None), probability=(0.5, 0.5))
# Invalid mask type
with pytest.raises(
AssertionError,
match="Expected type of `mask` to be tuple, actual type: <class 'dict'>",
):
space.sample(mask={"low": 0, "high": 1})
# Invalid mask length
with pytest.raises(
AssertionError, match="Expected length of `mask` to be 2, actual length: 1"
):
space.sample(mask=(None,))
# Invalid probability length
with pytest.raises(
AssertionError,
match="Expected length of `probability` to be 2, actual length: 1",
):
space.sample(probability=(0.5,))
# Invalid probability type
with pytest.raises(
AssertionError,
match="Expected type of `probability` to be tuple, actual type: <class 'list'>",
):
space.sample(probability=[0.5, 0.5])