mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 14:26:33 +00:00
Fixed Box.sample
bug for up-bounded discrete or boolean dtypes (#249)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
4dd526d370
commit
76559ed553
@@ -211,13 +211,14 @@ class Box(Space[NDArray[Any]]):
|
||||
|
||||
sample[upp_bounded] = (
|
||||
-self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
|
||||
+ self.high[upp_bounded]
|
||||
+ high[upp_bounded]
|
||||
)
|
||||
|
||||
sample[bounded] = self.np_random.uniform(
|
||||
low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
|
||||
)
|
||||
if self.dtype.kind == "i":
|
||||
|
||||
if self.dtype.kind in ["i", "u", "b"]:
|
||||
sample = np.floor(sample)
|
||||
|
||||
return sample.astype(self.dtype)
|
||||
|
@@ -82,7 +82,10 @@ all = [
|
||||
"moviepy >=1.0.0",
|
||||
"torch >=1.0.0",
|
||||
]
|
||||
testing = ["pytest ==7.1.3"]
|
||||
testing = [
|
||||
"pytest ==7.1.3",
|
||||
"scipy ==1.7.3",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://farama.org"
|
||||
|
@@ -7,6 +7,7 @@ from typing import Callable, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.stats
|
||||
|
||||
from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
|
||||
from gymnasium.utils import seeding
|
||||
@@ -68,22 +69,8 @@ def test_space_equality(space_1, space_2):
|
||||
assert space_1 != space_2
|
||||
|
||||
|
||||
# The expected sum of variance for an alpha of 0.05
|
||||
# CHI_SQUARED = [0] + [scipy.stats.chi2.isf(0.05, df=df) for df in range(1, 25)]
|
||||
CHI_SQUARED = np.array(
|
||||
[
|
||||
0.01,
|
||||
3.8414588206941285,
|
||||
5.991464547107983,
|
||||
7.814727903251178,
|
||||
9.487729036781158,
|
||||
11.070497693516355,
|
||||
12.59158724374398,
|
||||
14.067140449340167,
|
||||
15.507313055865454,
|
||||
16.91897760462045,
|
||||
]
|
||||
)
|
||||
# significance level of chi2 and KS tests
|
||||
ALPHA = 0.05
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -105,8 +92,31 @@ def test_sample(space: Space, n_trials: int = 1_000):
|
||||
assert len(samples) == n_trials
|
||||
|
||||
if isinstance(space, Box):
|
||||
# TODO: Add KS testing for continuous uniform distribution
|
||||
pass
|
||||
if space.dtype.kind == "f":
|
||||
test_function = ks_test
|
||||
elif space.dtype.kind in ["i", "u"]:
|
||||
test_function = chi2_test
|
||||
elif space.dtype.kind == "b":
|
||||
test_function = binary_chi2_test
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown test for Box(dtype={space.dtype})")
|
||||
|
||||
assert space.shape == space.low.shape == space.high.shape
|
||||
assert space.shape == samples.shape[1:]
|
||||
|
||||
# (n_trials, *space.shape) => (*space.shape, n_trials)
|
||||
samples = np.moveaxis(samples, 0, -1)
|
||||
|
||||
for index in np.ndindex(space.shape):
|
||||
low = space.low[index]
|
||||
high = space.high[index]
|
||||
sample = samples[index]
|
||||
|
||||
bounded_below = space.bounded_below[index]
|
||||
bounded_above = space.bounded_above[index]
|
||||
|
||||
test_function(sample, low, high, bounded_below, bounded_above)
|
||||
|
||||
elif isinstance(space, Discrete):
|
||||
expected_frequency = np.ones(space.n) * n_trials / space.n
|
||||
observed_frequency = np.zeros(space.n)
|
||||
@@ -120,7 +130,7 @@ def test_sample(space: Space, n_trials: int = 1_000):
|
||||
variance = np.sum(
|
||||
np.square(expected_frequency - observed_frequency) / expected_frequency
|
||||
)
|
||||
assert variance < CHI_SQUARED[degrees_of_freedom]
|
||||
assert variance < scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
|
||||
elif isinstance(space, MultiBinary):
|
||||
expected_frequency = n_trials / 2
|
||||
observed_frequency = np.sum(samples, axis=0)
|
||||
@@ -131,7 +141,7 @@ def test_sample(space: Space, n_trials: int = 1_000):
|
||||
2 * np.square(observed_frequency - expected_frequency) / expected_frequency
|
||||
)
|
||||
assert variance.shape == space.shape
|
||||
assert np.all(variance < CHI_SQUARED[1])
|
||||
assert np.all(variance < scipy.stats.chi2.isf(ALPHA, df=1))
|
||||
elif isinstance(space, MultiDiscrete):
|
||||
# Due to the multi-axis capability of MultiDiscrete, these functions need to be recursive and that the expected / observed numpy are of non-regular shapes
|
||||
def _generate_frequency(dim, func):
|
||||
@@ -167,7 +177,7 @@ def test_sample(space: Space, n_trials: int = 1_000):
|
||||
assert np.sum(exp_freq) == n_trials
|
||||
_variance = np.sum(np.square(exp_freq - obs_freq) / exp_freq)
|
||||
_degrees_of_freedom = dim - 1
|
||||
assert _variance < CHI_SQUARED[_degrees_of_freedom]
|
||||
assert _variance < scipy.stats.chi2.isf(ALPHA, df=_degrees_of_freedom)
|
||||
|
||||
_chi_squared_test(space.nvec, expected_frequency, observed_frequency)
|
||||
elif isinstance(space, Text):
|
||||
@@ -189,15 +199,112 @@ def test_sample(space: Space, n_trials: int = 1_000):
|
||||
variance = np.sum(
|
||||
np.square(expected_frequency - observed_frequency) / expected_frequency
|
||||
)
|
||||
if degrees_of_freedom == 61:
|
||||
# scipy.stats.chi2.isf(0.05, df=61)
|
||||
assert variance < 80.23209784876272
|
||||
else:
|
||||
assert variance < CHI_SQUARED[degrees_of_freedom]
|
||||
|
||||
assert variance < scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown sample testing for {type(space)}")
|
||||
|
||||
|
||||
def ks_test(sample, low, high, bounded_below, bounded_above):
|
||||
"""Perform Kolmogorov-Smirnov test on the sample. Automatically picks the
|
||||
distribution to test against based on the bounds.
|
||||
"""
|
||||
if bounded_below and bounded_above:
|
||||
# X ~ U(low, high)
|
||||
dist = scipy.stats.uniform(low, high - low)
|
||||
elif bounded_below and not bounded_above:
|
||||
# X ~ low + Exp(1.0)
|
||||
# => X - low ~ Exp(1.0)
|
||||
dist = scipy.stats.expon
|
||||
sample = sample - low
|
||||
elif not bounded_below and bounded_above:
|
||||
# X ~ high - Exp(1.0)
|
||||
# => high - X ~ Exp(1.0)
|
||||
dist = scipy.stats.expon
|
||||
sample = high - sample
|
||||
else:
|
||||
# X ~ N(0.0, 1.0)
|
||||
dist = scipy.stats.norm
|
||||
|
||||
_, p_value = scipy.stats.kstest(sample, dist.cdf)
|
||||
assert p_value >= ALPHA
|
||||
|
||||
|
||||
def chi2_test(sample, low, high, bounded_below, bounded_above):
|
||||
"""Perform chi-squared test on the sample. Automatically picks the distribution
|
||||
to test against based on the bounds.
|
||||
"""
|
||||
(n_trials,) = sample.shape
|
||||
|
||||
if bounded_below and bounded_above:
|
||||
# X ~ U(low, high)
|
||||
degrees_of_freedom = high - low + 1
|
||||
observed_frequency = np.bincount(sample - low, minlength=degrees_of_freedom)
|
||||
assert observed_frequency.shape == (degrees_of_freedom,)
|
||||
expected_frequency = np.ones(degrees_of_freedom) * n_trials / degrees_of_freedom
|
||||
elif bounded_below and not bounded_above:
|
||||
# X ~ low + Geom(1 - e^-1)
|
||||
# => X - low ~ Geom(1 - e^-1)
|
||||
dist = scipy.stats.geom(1 - 1 / np.e)
|
||||
observed_frequency = np.bincount(sample - low)
|
||||
x = np.arange(len(observed_frequency))
|
||||
expected_frequency = dist.pmf(x + 1) * n_trials
|
||||
expected_frequency[-1] += n_trials - np.sum(expected_frequency)
|
||||
elif not bounded_below and bounded_above:
|
||||
# X ~ high - Geom(1 - e^-1)
|
||||
# => high - X ~ Geom(1 - e^-1)
|
||||
dist = scipy.stats.geom(1 - 1 / np.e)
|
||||
observed_frequency = np.bincount(high - sample)
|
||||
x = np.arange(len(observed_frequency))
|
||||
expected_frequency = dist.pmf(x + 1) * n_trials
|
||||
expected_frequency[-1] += n_trials - np.sum(expected_frequency)
|
||||
else:
|
||||
# X ~ floor(N(0.0, 1.0)
|
||||
# => pmf(x) = cdf(x + 1) - cdf(x)
|
||||
lowest = np.min(sample)
|
||||
observed_frequency = np.bincount(sample - lowest)
|
||||
|
||||
normal_dist = scipy.stats.norm(0, 1)
|
||||
x = lowest + np.arange(len(observed_frequency))
|
||||
expected_frequency = normal_dist.cdf(x + 1) - normal_dist.cdf(x)
|
||||
expected_frequency[0] += normal_dist.cdf(lowest)
|
||||
expected_frequency *= n_trials
|
||||
expected_frequency[-1] += n_trials - np.sum(expected_frequency)
|
||||
|
||||
assert observed_frequency.shape == expected_frequency.shape
|
||||
variance = np.sum(
|
||||
np.square(expected_frequency - observed_frequency) / expected_frequency
|
||||
)
|
||||
degrees_of_freedom = len(observed_frequency) - 1
|
||||
critical_value = scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
|
||||
|
||||
assert variance < critical_value
|
||||
|
||||
|
||||
def binary_chi2_test(sample, low, high, bounded_below, bounded_above):
|
||||
"""Perform Chi-squared test on boolean samples."""
|
||||
assert bounded_below
|
||||
assert bounded_above
|
||||
|
||||
(n_trials,) = sample.shape
|
||||
|
||||
if low == high == 0:
|
||||
assert np.all(sample == 0)
|
||||
elif low == high == 1:
|
||||
assert np.all(sample == 1)
|
||||
else:
|
||||
expected_frequency = n_trials / 2
|
||||
observed_frequency = np.sum(sample)
|
||||
|
||||
# we can be lazy in the variance as the np.square is symmetric for the 0 and 1 categories
|
||||
variance = (
|
||||
2 * np.square(observed_frequency - expected_frequency) / expected_frequency
|
||||
)
|
||||
|
||||
critical_value = scipy.stats.chi2.isf(ALPHA, df=1)
|
||||
assert variance < critical_value
|
||||
|
||||
|
||||
SAMPLE_MASK_RNG, _ = seeding.np_random(1)
|
||||
|
||||
|
||||
@@ -215,6 +322,9 @@ SAMPLE_MASK_RNG, _ = seeding.np_random(1)
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
# Multi-discrete
|
||||
(np.array([1, 1], dtype=np.int8), np.array([0, 0], dtype=np.int8)),
|
||||
(
|
||||
@@ -264,7 +374,10 @@ def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
|
||||
np.square(expected_frequency - observed_frequency)
|
||||
/ np.clip(expected_frequency, 1, None)
|
||||
)
|
||||
assert variance < CHI_SQUARED[degrees_of_freedom]
|
||||
if degrees_of_freedom == 0:
|
||||
assert variance == 0
|
||||
else:
|
||||
assert variance < scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
|
||||
elif isinstance(space, MultiBinary):
|
||||
expected_frequency = (
|
||||
np.ones(space.shape) * np.where(mask == 2, 0.5, mask) * n_trials
|
||||
@@ -279,7 +392,7 @@ def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
|
||||
/ np.clip(expected_frequency, 1, None)
|
||||
)
|
||||
assert variance.shape == space.shape
|
||||
assert np.all(variance < CHI_SQUARED[1])
|
||||
assert np.all(variance < scipy.stats.chi2.isf(ALPHA, df=1))
|
||||
elif isinstance(space, MultiDiscrete):
|
||||
# Due to the multi-axis capability of MultiDiscrete, these functions need to be recursive and that the expected / observed numpy are of non-regular shapes
|
||||
def _generate_frequency(
|
||||
@@ -332,7 +445,13 @@ def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
|
||||
np.square(exp_freq - obs_freq) / np.clip(exp_freq, 1, None)
|
||||
)
|
||||
_degrees_of_freedom = max(np.sum(_mask) - 1, 0)
|
||||
assert _variance < CHI_SQUARED[_degrees_of_freedom]
|
||||
|
||||
if _degrees_of_freedom == 0:
|
||||
assert _variance == 0
|
||||
else:
|
||||
assert _variance < scipy.stats.chi2.isf(
|
||||
ALPHA, df=_degrees_of_freedom
|
||||
)
|
||||
|
||||
_chi_squared_test(space.nvec, mask, expected_frequency, observed_frequency)
|
||||
elif isinstance(space, Text):
|
||||
@@ -370,14 +489,11 @@ def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
|
||||
np.square(expected_frequency - observed_frequency)
|
||||
/ np.clip(expected_frequency, 1, None)
|
||||
)
|
||||
if degrees_of_freedom == 26:
|
||||
# scipy.stats.chi2.isf(0.05, df=29)
|
||||
assert variance < 38.88513865983007
|
||||
elif degrees_of_freedom == 31:
|
||||
# scipy.stats.chi2.isf(0.05, df=31)
|
||||
assert variance < 44.985343280365136
|
||||
|
||||
if degrees_of_freedom == 0:
|
||||
assert variance == 0
|
||||
else:
|
||||
assert variance < CHI_SQUARED[degrees_of_freedom]
|
||||
assert variance < scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@@ -20,6 +20,9 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
12,
|
||||
3,
|
||||
4,
|
||||
# Multi-discrete
|
||||
4,
|
||||
10,
|
||||
|
@@ -24,6 +24,13 @@ TESTING_FUNDAMENTAL_SPACES = [
|
||||
Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
|
||||
Box(low=-np.inf, high=0.0, shape=(2, 1)),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 1)),
|
||||
Box(low=0, high=255, shape=(2, 2, 3), dtype=np.uint8),
|
||||
Box(low=np.array([0, 0, 1]), high=np.array([1, 0, 1]), dtype=np.bool_),
|
||||
Box(
|
||||
low=np.array([-np.inf, -np.inf, 0, -10]),
|
||||
high=np.array([np.inf, 0, np.inf, 10]),
|
||||
dtype=np.int32,
|
||||
),
|
||||
MultiDiscrete([2, 2]),
|
||||
MultiDiscrete([[2, 3], [3, 2]]),
|
||||
MultiBinary(8),
|
||||
|
Reference in New Issue
Block a user