Readded overwritten changes for offset functionality for Discrete spaces (#2470)

Co-authored-by: J K Terry <justinkterry@gmail.com>
This commit is contained in:
Ishan Manchanda
2021-10-30 21:42:01 +05:30
committed by GitHub
parent 531d4d02db
commit 103b7633f5
4 changed files with 36 additions and 7 deletions

View File

@@ -59,7 +59,7 @@ print(env.observation_space.low)
```
- There are multiple Space types available in gym:
- `Box` describes an n-dimensional continuous space. It's a bounded space where we can define the upper and lower limits which describe the valid values our observations can take.
- `Discrete` describes a discrete space where { 0, 1, ......., n-1} are the possible values our observation/action can take.
- `Discrete` describes a discrete space where { 0, 1, ..., n-1 } are the possible values our observation/action can take. Values can be shifted to { a, a+1, ..., a+n-1 } by passing the optional `start` argument (here `start=a`).
- `Dict` represents a dictionary of simple spaces.
- `Tuple` represents a tuple of simple spaces
- `MultiBinary` creates an n-shape binary space. Argument n can be a number or a `list` of numbers.
@@ -73,6 +73,10 @@ print(env.observation_space.low)
print(observation_space.sample())
#> 1
observation_space = Discrete(5, start=-2)
print(observation_space.sample())
#> -2
observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
print(observation_space.sample())
#> OrderedDict([('position', 0), ('velocity', 1)])

View File

@@ -6,6 +6,7 @@ from os import path
class PendulumEnv(gym.Env):
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 30}
def __init__(self, g=10.0):

View File

@@ -5,19 +5,24 @@ from .space import Space
class Discrete(Space):
r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`.
A start value can be optionally specified to shift the range
to :math:`\{ a, a+1, \\dots, a+n-1 \}`.
Example::
>>> Discrete(2)
>>> Discrete(3, start=-1) # {-1, 0, 1}
"""
def __init__(self, n, seed=None):
assert n >= 0
def __init__(self, n, seed=None, start=0):
assert n >= 0 and isinstance(start, (int, np.integer))
self.n = n
self.start = int(start)
super(Discrete, self).__init__((), np.int64, seed)
def sample(self):
return self.np_random.randint(self.n)
return self.start + self.np_random.randint(self.n)
def contains(self, x):
if isinstance(x, int):
@@ -28,10 +33,16 @@ class Discrete(Space):
as_int = int(x)
else:
return False
return as_int >= 0 and as_int < self.n
return self.start <= as_int < self.start + self.n
def __repr__(self):
# Mention the start offset in the representation only when it is non-zero;
# otherwise fall back to the shorter "Discrete(n)" form.
if self.start != 0:
return "Discrete(%d, start=%d)" % (self.n, self.start)
return "Discrete(%d)" % self.n
def __eq__(self, other):
return isinstance(other, Discrete) and self.n == other.n
return (
isinstance(other, Discrete)
and self.n == other.n
and self.start == other.start
)

View File

@@ -11,6 +11,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
"space",
[
Discrete(3),
Discrete(5, start=-2),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
@@ -20,6 +21,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
]
),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
MultiDiscrete([2, 2, 100]),
MultiBinary(10),
Dict(
@@ -56,6 +58,7 @@ def test_roundtripping(space):
"space",
[
Discrete(3),
Discrete(5, start=-2),
Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
Box(low=-np.inf, high=np.inf, shape=(1, 3)),
Tuple([Discrete(5), Discrete(10)]),
@@ -66,6 +69,7 @@ def test_roundtripping(space):
]
),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
Tuple((Discrete(5), Discrete(2), Discrete(2, start=-6))),
MultiDiscrete([2, 2, 100]),
MultiBinary(6),
Dict(
@@ -88,6 +92,7 @@ def test_equality(space):
"spaces",
[
(Discrete(3), Discrete(4)),
(Discrete(3), Discrete(3, start=-1)),
(MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
(MultiBinary(8), MultiBinary(7)),
(
@@ -99,6 +104,10 @@ def test_equality(space):
Box(low=0.0, high=np.inf, shape=(2, 1)),
),
(Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
(
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5, start=7), Discrete(10)]),
),
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
],
@@ -112,6 +121,7 @@ def test_inequality(spaces):
"space",
[
Discrete(5),
Discrete(8, start=-20),
Box(low=0, high=255, shape=(2,), dtype="uint8"),
Box(low=-np.inf, high=np.inf, shape=(3, 3)),
Box(low=1.0, high=np.inf, shape=(3, 3)),
@@ -133,7 +143,7 @@ def test_sample(space):
else:
expected_mean = 0.0
elif isinstance(space, Discrete):
expected_mean = space.n / 2
expected_mean = space.start + space.n / 2
else:
raise NotImplementedError
np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())
@@ -246,6 +256,7 @@ def test_box_dtype_check():
"space",
[
Discrete(3),
Discrete(3, start=-4),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
@@ -298,6 +309,7 @@ def sample_equal(sample1, sample2):
"space",
[
Discrete(3),
Discrete(3, start=-4),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
@@ -335,6 +347,7 @@ def test_seed_reproducibility(space):
[
Tuple([Discrete(100), Discrete(100)]),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Discrete(5, start=10)]),
Tuple(
[
Discrete(5),