mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-17 20:39:12 +00:00
Add probability masking to space.sample
(#1310)
Co-authored-by: Mario Jerez <jerezmario1@gmail.com>
This commit is contained in:
@@ -105,3 +105,68 @@ def test_bad_seed():
|
||||
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",
|
||||
):
|
||||
space.seed(0.0)
|
||||
|
||||
|
||||
def test_oneof_sample():
    """Tests `Tuple.sample` with no arguments, with a mask and with probabilities.

    NOTE: despite the `oneof` in its name, this test builds a `Tuple` space of a
    `Discrete(2)` and a 2-element `Box`; the name is kept so existing test
    selections keep working.
    """
    space = gym.spaces.Tuple([Discrete(2), Box(-1, 1, shape=(2,))])

    # Unmasked sampling: one element per subspace, each valid for its subspace.
    sample = space.sample()
    assert isinstance(sample, tuple)
    assert len(sample) == 2
    assert space.spaces[0].contains(sample[0])
    assert space.spaces[1].contains(sample[1])

    # Masked sampling: the mask only enables value 0 of the Discrete subspace;
    # `None` leaves the Box subspace unmasked.
    mask = (np.array([1, 0], dtype=np.int8), None)
    sample = space.sample(mask=mask)
    assert space.spaces[0].contains(sample[0])
    assert space.spaces[1].contains(sample[1])
    assert sample[0] == 0

    # Probability sampling: empirical frequencies should approximate the
    # requested distribution. Seed first so this statistical check is
    # reproducible instead of failing on a rare unlucky draw.
    space.seed(0)
    num_samples = 1000
    probability = (np.array([0.8, 0.2], dtype=np.float64), None)
    samples_discrete = np.array(
        [space.sample(probability=probability)[0] for _ in range(num_samples)]
    )
    counts = np.bincount(samples_discrete, minlength=2) / num_samples
    np.testing.assert_allclose(counts, probability[0], atol=0.05)
|
||||
|
||||
|
||||
def test_invalid_sample_inputs():
    """Checks that `sample` raises the expected error for malformed arguments."""
    space = gym.spaces.Tuple([Discrete(2), Box(-1, 1, shape=(2,))])

    # Each case: (expected exception, expected message pattern, `sample` kwargs).
    invalid_calls = (
        # Providing both mask and probability
        (
            ValueError,
            "Only one of `mask` or `probability` can be provided.",
            {"mask": (None, None), "probability": (0.5, 0.5)},
        ),
        # Invalid mask type
        (
            AssertionError,
            "Expected type of `mask` to be tuple, actual type: <class 'dict'>",
            {"mask": {"low": 0, "high": 1}},
        ),
        # Invalid mask length
        (
            AssertionError,
            "Expected length of `mask` to be 2, actual length: 1",
            {"mask": (None,)},
        ),
        # Invalid probability length
        (
            AssertionError,
            "Expected length of `probability` to be 2, actual length: 1",
            {"probability": (0.5,)},
        ),
        # Invalid probability type
        (
            AssertionError,
            "Expected type of `probability` to be tuple, actual type: <class 'list'>",
            {"probability": [0.5, 0.5]},
        ),
    )

    for error_type, message, sample_kwargs in invalid_calls:
        with pytest.raises(error_type, match=message):
            space.sample(**sample_kwargs)
|
||||
|
Reference in New Issue
Block a user