2022-09-03 23:39:23 +01:00
|
|
|
import numpy as np
|
|
|
|
|
2022-09-08 10:10:07 +01:00
|
|
|
from gymnasium.spaces import MultiBinary
|
2022-09-03 23:39:23 +01:00
|
|
|
|
|
|
|
|
2022-09-03 22:56:29 +01:00
|
|
|
def test_sample():
|
2022-09-03 23:39:23 +01:00
|
|
|
space = MultiBinary(4)
|
|
|
|
|
|
|
|
sample = space.sample(mask=np.array([0, 0, 1, 1], dtype=np.int8))
|
|
|
|
assert np.all(sample == [0, 0, 1, 1])
|
|
|
|
|
|
|
|
sample = space.sample(mask=np.array([0, 1, 2, 2], dtype=np.int8))
|
|
|
|
assert sample[0] == 0 and sample[1] == 1
|
|
|
|
assert sample[2] == 0 or sample[2] == 1
|
|
|
|
assert sample[3] == 0 or sample[3] == 1
|
|
|
|
|
|
|
|
space = MultiBinary(np.array([2, 3]))
|
|
|
|
sample = space.sample(mask=np.array([[0, 0, 0], [1, 1, 1]], dtype=np.int8))
|
|
|
|
assert np.all(sample == [[0, 0, 0], [1, 1, 1]]), sample
|
2025-02-21 13:39:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_sample_probabilities():
|
|
|
|
# Test sampling with probabilities
|
|
|
|
space = MultiBinary(4)
|
|
|
|
probabilities = np.array([0, 1, 0.5, 0.25], dtype=np.float64)
|
|
|
|
|
|
|
|
samples = [space.sample(probability=probabilities) for _ in range(10000)]
|
|
|
|
assert all(sample in space for sample in samples)
|
|
|
|
samples = np.array(samples)
|
|
|
|
|
|
|
|
# Check empirical probabilities
|
|
|
|
for i in range(4):
|
|
|
|
counts = np.sum(samples[:, i]) / len(samples)
|
|
|
|
np.testing.assert_allclose(counts, probabilities[i], atol=0.05)
|