mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-19 05:25:54 +00:00
Readded overwritten changes for offset functionality for Discrete spaces (#2470)
Co-authored-by: J K Terry <justinkterry@gmail.com>
This commit is contained in:
@@ -59,7 +59,7 @@ print(env.observation_space.low)
|
||||
```
|
||||
- There are multiple types of Space types inherently available in gym:
|
||||
- `Box` describes an n-dimensional continuous space. Its a bounded space where we can define the upper and lower limit which describe the valid values our observations can take.
|
||||
- `Discrete` describes a discrete space where { 0, 1, ......., n-1} are the possible values our observation/action can take.
|
||||
- `Discrete` describes a discrete space where { 0, 1, ......., n-1} are the possible values our observation/action can take. Values can be shifted to { a, a+1, ......., a+n-1} using an optional argument.
|
||||
- `Dict` represents a dictionary of simple spaces.
|
||||
- `Tuple` represents a tuple of simple spaces
|
||||
- `MultiBinary` creates a n-shape binary space. Argument n can be a number or a `list` of numbers
|
||||
@@ -72,6 +72,10 @@ print(env.observation_space.low)
|
||||
observation_space = Discrete(4)
|
||||
print(observation_space.sample())
|
||||
#> 1
|
||||
|
||||
observation_space = Discrete(5, start=-2)
|
||||
print(observation_space.sample())
|
||||
#> -2
|
||||
|
||||
observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||
print(observation_space.sample())
|
||||
|
@@ -6,6 +6,7 @@ from os import path
|
||||
|
||||
|
||||
class PendulumEnv(gym.Env):
|
||||
|
||||
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 30}
|
||||
|
||||
def __init__(self, g=10.0):
|
||||
|
@@ -5,19 +5,24 @@ from .space import Space
|
||||
class Discrete(Space):
|
||||
r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`.
|
||||
|
||||
A start value can be optionally specified to shift the range
|
||||
to :math:`\{ a, a+1, \\dots, a+n-1 \}`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> Discrete(2)
|
||||
>>> Discrete(3, start=-1) # {-1, 0, 1}
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n, seed=None):
|
||||
assert n >= 0
|
||||
def __init__(self, n, seed=None, start=0):
|
||||
assert n >= 0 and isinstance(start, (int, np.integer))
|
||||
self.n = n
|
||||
self.start = int(start)
|
||||
super(Discrete, self).__init__((), np.int64, seed)
|
||||
|
||||
def sample(self):
|
||||
return self.np_random.randint(self.n)
|
||||
return self.start + self.np_random.randint(self.n)
|
||||
|
||||
def contains(self, x):
|
||||
if isinstance(x, int):
|
||||
@@ -28,10 +33,16 @@ class Discrete(Space):
|
||||
as_int = int(x)
|
||||
else:
|
||||
return False
|
||||
return as_int >= 0 and as_int < self.n
|
||||
return self.start <= as_int < self.start + self.n
|
||||
|
||||
def __repr__(self):
|
||||
if self.start != 0:
|
||||
return "Discrete(%d, start=%d)" % (self.n, self.start)
|
||||
return "Discrete(%d)" % self.n
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Discrete) and self.n == other.n
|
||||
return (
|
||||
isinstance(other, Discrete)
|
||||
and self.n == other.n
|
||||
and self.start == other.start
|
||||
)
|
||||
|
@@ -11,6 +11,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Discrete(5, start=-2),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
@@ -20,6 +21,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(10),
|
||||
Dict(
|
||||
@@ -56,6 +58,7 @@ def test_roundtripping(space):
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Discrete(5, start=-2),
|
||||
Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(1, 3)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
@@ -66,6 +69,7 @@ def test_roundtripping(space):
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2, start=-6))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(6),
|
||||
Dict(
|
||||
@@ -88,6 +92,7 @@ def test_equality(space):
|
||||
"spaces",
|
||||
[
|
||||
(Discrete(3), Discrete(4)),
|
||||
(Discrete(3), Discrete(3, start=-1)),
|
||||
(MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
|
||||
(MultiBinary(8), MultiBinary(7)),
|
||||
(
|
||||
@@ -99,6 +104,10 @@ def test_equality(space):
|
||||
Box(low=0.0, high=np.inf, shape=(2, 1)),
|
||||
),
|
||||
(Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
|
||||
(
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5, start=7), Discrete(10)]),
|
||||
),
|
||||
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
|
||||
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
|
||||
],
|
||||
@@ -112,6 +121,7 @@ def test_inequality(spaces):
|
||||
"space",
|
||||
[
|
||||
Discrete(5),
|
||||
Discrete(8, start=-20),
|
||||
Box(low=0, high=255, shape=(2,), dtype="uint8"),
|
||||
Box(low=-np.inf, high=np.inf, shape=(3, 3)),
|
||||
Box(low=1.0, high=np.inf, shape=(3, 3)),
|
||||
@@ -133,7 +143,7 @@ def test_sample(space):
|
||||
else:
|
||||
expected_mean = 0.0
|
||||
elif isinstance(space, Discrete):
|
||||
expected_mean = space.n / 2
|
||||
expected_mean = space.start + space.n / 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())
|
||||
@@ -246,6 +256,7 @@ def test_box_dtype_check():
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Discrete(3, start=-4),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
@@ -298,6 +309,7 @@ def sample_equal(sample1, sample2):
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Discrete(3, start=-4),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
@@ -335,6 +347,7 @@ def test_seed_reproducibility(space):
|
||||
[
|
||||
Tuple([Discrete(100), Discrete(100)]),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5), Discrete(5, start=10)]),
|
||||
Tuple(
|
||||
[
|
||||
Discrete(5),
|
||||
|
Reference in New Issue
Block a user