Make MultiDiscrete a Tuple-like space (#2364)

* Make MultiDiscrete a Tuple-like space

* Update test cases for MultiDiscrete
This commit is contained in:
Xuehai Pan
2021-09-12 00:54:52 +08:00
committed by GitHub
parent 8da6224b72
commit a8f551ed44
2 changed files with 72 additions and 5 deletions

View File

@@ -1,5 +1,7 @@
import numpy as np
from gym.logger import warn
from .space import Space
from .discrete import Discrete
class MultiDiscrete(Space):
@@ -24,7 +26,6 @@ class MultiDiscrete(Space):
"""
def __init__(self, nvec, dtype=np.int64):
"""
nvec: vector of counts of each categorical variable
"""
@@ -54,5 +55,19 @@ class MultiDiscrete(Space):
def __repr__(self):
    """Return the text ``MultiDiscrete(<nvec>)`` describing this space."""
    template = "MultiDiscrete({})"
    return template.format(self.nvec)
def __getitem__(self, index):
    """Build the subspace selected by applying ``index`` to ``self.nvec``.

    A selection with zero dimensions becomes a ``Discrete`` space; any other
    selection becomes a ``MultiDiscrete`` created from the selected counts
    and this space's ``dtype``. The subspace's random generator is then set
    to this space's current generator state.
    """
    selected = self.nvec[index]
    is_single_entry = selected.ndim == 0
    subspace = Discrete(selected) if is_single_entry else MultiDiscrete(selected, self.dtype)
    # for reproducibility: start the subspace from the parent's RNG state
    child_rng = subspace.np_random
    child_rng.set_state(self.np_random.get_state())
    return subspace
def __len__(self):
    """Return ``len(self.nvec)``.

    A warning is emitted first when ``nvec`` has two or more dimensions.
    """
    counts = self.nvec
    if counts.ndim > 1:
        warn("Get length of a multi-dimensional MultiDiscrete space.")
    return len(counts)
def __eq__(self, other):
    """Return whether ``other`` is a MultiDiscrete with an identical ``nvec``.

    Two spaces are equal only when ``other`` is a ``MultiDiscrete`` and both
    ``nvec`` arrays have the same shape and the same entries.

    Returns:
        bool: ``True`` when the spaces are equal, ``False`` otherwise.
    """
    if not isinstance(other, MultiDiscrete):
        return False
    # Compare shapes before values: ``==`` on arrays broadcasts, so without
    # this guard ``nvec=[3]`` would compare equal to ``nvec=[3, 3]``, and
    # shapes that cannot be broadcast together make the elementwise
    # comparison warn or raise depending on the NumPy version.
    if self.nvec.shape != other.nvec.shape:
        return False
    # ``bool`` turns NumPy's boolean scalar into a plain Python bool.
    return bool(np.all(self.nvec == other.nvec))

View File

@@ -1,6 +1,5 @@
import json # note: ujson fails this test due to float equality
import copy
from collections import OrderedDict
import numpy as np
import pytest
@@ -244,6 +243,10 @@ def convert_sample_hashable(sample):
return sample
def sample_equal(sample1, sample2):
    """Return whether two samples compare equal after ``convert_sample_hashable``."""
    hashable1 = convert_sample_hashable(sample1)
    hashable2 = convert_sample_hashable(sample2)
    return hashable1 == hashable2
@pytest.mark.parametrize(
"space",
[
@@ -277,9 +280,7 @@ def test_seed_reproducibility(space):
space2.seed(None)
assert space1.seed(0) == space2.seed(0)
sample1, sample2 = space1.sample(), space2.sample()
assert convert_sample_hashable(sample1) == convert_sample_hashable(sample2)
assert sample_equal(space1.sample(), space2.sample())
@pytest.mark.parametrize(
@@ -314,3 +315,54 @@ def test_seed_subspace_incorrelated(space):
]
assert len(states) == len(set(states))
def test_multidiscrete_as_tuple():
    """Check indexing, slicing and ``len`` on 1D and 2D MultiDiscrete spaces."""
    # 1D multi-discrete
    flat = MultiDiscrete([3, 4, 5])
    assert flat.shape == (3,)
    flat_cases = [
        (np.s_[0], Discrete(3)),
        (np.s_[0:1], MultiDiscrete([3])),
        (np.s_[0:2], MultiDiscrete([3, 4])),
    ]
    for index, expected in flat_cases:
        assert flat[index] == expected
    # a full slice gives an equal space that is a different object
    assert flat[:] == flat
    assert flat[:] is not flat
    assert len(flat) == 3

    # 2D multi-discrete
    grid = MultiDiscrete([[3, 4, 5], [6, 7, 8]])
    assert grid.shape == (2, 3)
    grid_cases = [
        (np.s_[0, 1], Discrete(4)),
        (np.s_[0], MultiDiscrete([3, 4, 5])),
        (np.s_[0:1], MultiDiscrete([[3, 4, 5]])),
        (np.s_[0:2, :], MultiDiscrete([[3, 4, 5], [6, 7, 8]])),
        (np.s_[:, 0:1], MultiDiscrete([[3], [6]])),
        (np.s_[0:2, 0:2], MultiDiscrete([[3, 4], [6, 7]])),
    ]
    for index, expected in grid_cases:
        assert grid[index] == expected
    # full slices along one or both axes: equal space, different object
    assert grid[:] == grid
    assert grid[:] is not grid
    assert grid[:, :] == grid
    assert grid[:, :] is not grid
def test_multidiscrete_subspace_reproducibility():
    """Check that subspaces taken with the same index yield equal samples."""
    # 1D multi-discrete
    flat = MultiDiscrete([100, 200, 300])
    flat.seed(None)
    flat_indices = (np.s_[0], np.s_[0:1], np.s_[0:2], np.s_[:])
    for index in flat_indices:
        assert sample_equal(flat[index].sample(), flat[index].sample())
    # the full-slice subspace is sampled first, then the parent space itself
    assert sample_equal(flat[:].sample(), flat.sample())

    # 2D multi-discrete
    grid = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
    grid.seed(None)
    grid_indices = (
        np.s_[0, 1],
        np.s_[0],
        np.s_[0:1],
        np.s_[0:2, :],
        np.s_[:, 0:1],
        np.s_[0:2, 0:2],
        np.s_[:],
        np.s_[:, :],
    )
    for index in grid_indices:
        assert sample_equal(grid[index].sample(), grid[index].sample())
    # the full-slice subspace is sampled first, then the parent space itself
    assert sample_equal(grid[:, :].sample(), grid.sample())