Add probability masking to space.sample (#1310)

Co-authored-by: Mario Jerez <jerezmario1@gmail.com>
This commit is contained in:
Mark Towers
2025-02-21 13:39:23 +00:00
committed by GitHub
parent 1dffcc6ed4
commit e4c1f901e9
21 changed files with 1053 additions and 182 deletions

View File

@@ -373,3 +373,15 @@ def test_sample_mask():
match=re.escape("Box.sample cannot be provided a mask, actual value: "),
):
space.sample(mask=np.array([0, 1, 0], dtype=np.int8))
def test_sample_probability_mask():
"""Box cannot have a probability mask applied."""
space = Box(0, 1)
with pytest.raises(
gym.error.Error,
match=re.escape(
"Box.sample cannot be provided a probability mask, actual value: "
),
):
space.sample(probability=np.array([0, 1, 0], dtype=np.float64))