mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 15:11:51 +00:00
Make MultiDiscrete
a Tuple
-like space (#2364)
* Make MultiDiscrete a Tuple-like space * Update test cases for MultiDiscrete
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
from gym.logger import warn
|
||||||
from .space import Space
|
from .space import Space
|
||||||
|
from .discrete import Discrete
|
||||||
|
|
||||||
|
|
||||||
class MultiDiscrete(Space):
|
class MultiDiscrete(Space):
|
||||||
@@ -24,7 +26,6 @@ class MultiDiscrete(Space):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, nvec, dtype=np.int64):
|
def __init__(self, nvec, dtype=np.int64):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
nvec: vector of counts of each categorical variable
|
nvec: vector of counts of each categorical variable
|
||||||
"""
|
"""
|
||||||
@@ -54,5 +55,19 @@ class MultiDiscrete(Space):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "MultiDiscrete({})".format(self.nvec)
|
return "MultiDiscrete({})".format(self.nvec)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
nvec = self.nvec[index]
|
||||||
|
if nvec.ndim == 0:
|
||||||
|
subspace = Discrete(nvec)
|
||||||
|
else:
|
||||||
|
subspace = MultiDiscrete(nvec, self.dtype)
|
||||||
|
subspace.np_random.set_state(self.np_random.get_state()) # for reproducibility
|
||||||
|
return subspace
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.nvec.ndim >= 2:
|
||||||
|
warn("Get length of a multi-dimensional MultiDiscrete space.")
|
||||||
|
return len(self.nvec)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
import json # note: ujson fails this test due to float equality
|
import json # note: ujson fails this test due to float equality
|
||||||
import copy
|
import copy
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -244,6 +243,10 @@ def convert_sample_hashable(sample):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def sample_equal(sample1, sample2):
|
||||||
|
return convert_sample_hashable(sample1) == convert_sample_hashable(sample2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"space",
|
"space",
|
||||||
[
|
[
|
||||||
@@ -277,9 +280,7 @@ def test_seed_reproducibility(space):
|
|||||||
space2.seed(None)
|
space2.seed(None)
|
||||||
|
|
||||||
assert space1.seed(0) == space2.seed(0)
|
assert space1.seed(0) == space2.seed(0)
|
||||||
|
assert sample_equal(space1.sample(), space2.sample())
|
||||||
sample1, sample2 = space1.sample(), space2.sample()
|
|
||||||
assert convert_sample_hashable(sample1) == convert_sample_hashable(sample2)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -314,3 +315,54 @@ def test_seed_subspace_incorrelated(space):
|
|||||||
]
|
]
|
||||||
|
|
||||||
assert len(states) == len(set(states))
|
assert len(states) == len(set(states))
|
||||||
|
|
||||||
|
|
||||||
|
def test_multidiscrete_as_tuple():
|
||||||
|
# 1D multi-discrete
|
||||||
|
space = MultiDiscrete([3, 4, 5])
|
||||||
|
|
||||||
|
assert space.shape == (3,)
|
||||||
|
assert space[0] == Discrete(3)
|
||||||
|
assert space[0:1] == MultiDiscrete([3])
|
||||||
|
assert space[0:2] == MultiDiscrete([3, 4])
|
||||||
|
assert space[:] == space and space[:] is not space
|
||||||
|
assert len(space) == 3
|
||||||
|
|
||||||
|
# 2D multi-discrete
|
||||||
|
space = MultiDiscrete([[3, 4, 5], [6, 7, 8]])
|
||||||
|
|
||||||
|
assert space.shape == (2, 3)
|
||||||
|
assert space[0, 1] == Discrete(4)
|
||||||
|
assert space[0] == MultiDiscrete([3, 4, 5])
|
||||||
|
assert space[0:1] == MultiDiscrete([[3, 4, 5]])
|
||||||
|
assert space[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]])
|
||||||
|
assert space[:, 0:1] == MultiDiscrete([[3], [6]])
|
||||||
|
assert space[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]])
|
||||||
|
assert space[:] == space and space[:] is not space
|
||||||
|
assert space[:, :] == space and space[:, :] is not space
|
||||||
|
|
||||||
|
|
||||||
|
def test_multidiscrete_subspace_reproducibility():
|
||||||
|
# 1D multi-discrete
|
||||||
|
space = MultiDiscrete([100, 200, 300])
|
||||||
|
space.seed(None)
|
||||||
|
|
||||||
|
assert sample_equal(space[0].sample(), space[0].sample())
|
||||||
|
assert sample_equal(space[0:1].sample(), space[0:1].sample())
|
||||||
|
assert sample_equal(space[0:2].sample(), space[0:2].sample())
|
||||||
|
assert sample_equal(space[:].sample(), space[:].sample())
|
||||||
|
assert sample_equal(space[:].sample(), space.sample())
|
||||||
|
|
||||||
|
# 2D multi-discrete
|
||||||
|
space = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
|
||||||
|
space.seed(None)
|
||||||
|
|
||||||
|
assert sample_equal(space[0, 1].sample(), space[0, 1].sample())
|
||||||
|
assert sample_equal(space[0].sample(), space[0].sample())
|
||||||
|
assert sample_equal(space[0:1].sample(), space[0:1].sample())
|
||||||
|
assert sample_equal(space[0:2, :].sample(), space[0:2, :].sample())
|
||||||
|
assert sample_equal(space[:, 0:1].sample(), space[:, 0:1].sample())
|
||||||
|
assert sample_equal(space[0:2, 0:2].sample(), space[0:2, 0:2].sample())
|
||||||
|
assert sample_equal(space[:].sample(), space[:].sample())
|
||||||
|
assert sample_equal(space[:, :].sample(), space[:, :].sample())
|
||||||
|
assert sample_equal(space[:, :].sample(), space.sample())
|
||||||
|
Reference in New Issue
Block a user