mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 15:11:51 +00:00
Make MultiDiscrete a Tuple-like space (#2364)
* Make MultiDiscrete a Tuple-like space
* Update test cases for MultiDiscrete
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import numpy as np
|
||||
from gym.logger import warn
|
||||
from .space import Space
|
||||
from .discrete import Discrete
|
||||
|
||||
|
||||
class MultiDiscrete(Space):
|
||||
@@ -24,7 +26,6 @@ class MultiDiscrete(Space):
|
||||
"""
|
||||
|
||||
def __init__(self, nvec, dtype=np.int64):
|
||||
|
||||
"""
|
||||
nvec: vector of counts of each categorical variable
|
||||
"""
|
||||
@@ -54,5 +55,19 @@ class MultiDiscrete(Space):
|
||||
def __repr__(self):
    """Return the text ``MultiDiscrete(<nvec>)`` with ``self.nvec`` formatted in."""
    template = "MultiDiscrete({})"
    return template.format(self.nvec)
|
||||
|
||||
def __getitem__(self, index):
    """Return the sub-space selected by ``index``.

    ``index`` is applied directly to ``self.nvec``. A zero-dimensional
    result becomes a ``Discrete`` space; any other result becomes a
    ``MultiDiscrete`` that keeps this space's ``dtype``. The sub-space's
    random state is overwritten with a copy of this space's state.
    """
    selected = self.nvec[index]
    if selected.ndim != 0:
        child = MultiDiscrete(selected, self.dtype)
    else:
        child = Discrete(selected)

    # Fetch the child's generator before reading the parent's state, in the
    # same order as the single-expression form this replaces.
    child_rng = child.np_random
    child_rng.set_state(self.np_random.get_state())  # for reproducibility
    return child
|
||||
|
||||
def __len__(self):
    """Return ``len(self.nvec)``.

    A warning is emitted through ``warn`` when ``nvec`` has two or more
    dimensions.
    """
    counts = self.nvec
    if counts.ndim > 1:
        warn("Get length of a multi-dimensional MultiDiscrete space.")
    return len(counts)
|
||||
|
||||
def __eq__(self, other):
    """Return True when ``other`` is a MultiDiscrete with an identical ``nvec``.

    ``np.array_equal`` compares shapes before comparing elements. A bare
    ``np.all(self.nvec == other.nvec)`` would broadcast instead, so spaces
    such as ``MultiDiscrete([3])`` and ``MultiDiscrete([3, 3])`` would be
    reported equal, and shapes that cannot be broadcast would not yield a
    clean ``False``. The result is a plain ``bool``.
    """
    if not isinstance(other, MultiDiscrete):
        return False
    return bool(np.array_equal(self.nvec, other.nvec))
|
||||
|
@@ -1,6 +1,5 @@
|
||||
import json # note: ujson fails this test due to float equality
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -244,6 +243,10 @@ def convert_sample_hashable(sample):
|
||||
return sample
|
||||
|
||||
|
||||
def sample_equal(sample1, sample2):
    """Return whether two samples are equal after ``convert_sample_hashable``."""
    hashable1 = convert_sample_hashable(sample1)
    hashable2 = convert_sample_hashable(sample2)
    return hashable1 == hashable2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space",
|
||||
[
|
||||
@@ -277,9 +280,7 @@ def test_seed_reproducibility(space):
|
||||
space2.seed(None)
|
||||
|
||||
assert space1.seed(0) == space2.seed(0)
|
||||
|
||||
sample1, sample2 = space1.sample(), space2.sample()
|
||||
assert convert_sample_hashable(sample1) == convert_sample_hashable(sample2)
|
||||
assert sample_equal(space1.sample(), space2.sample())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -314,3 +315,54 @@ def test_seed_subspace_incorrelated(space):
|
||||
]
|
||||
|
||||
assert len(states) == len(set(states))
|
||||
|
||||
|
||||
def test_multidiscrete_as_tuple():
    """Check that indexing and ``len`` on MultiDiscrete yield the expected sub-spaces."""
    # 1D multi-discrete
    flat = MultiDiscrete([3, 4, 5])

    assert flat.shape == (3,)
    assert flat[0] == Discrete(3)
    assert flat[0:1] == MultiDiscrete([3])
    assert flat[0:2] == MultiDiscrete([3, 4])
    # A full slice is an equal space, but a new object.
    assert flat[:] == flat
    assert flat[:] is not flat
    assert len(flat) == 3

    # 2D multi-discrete
    grid = MultiDiscrete([[3, 4, 5], [6, 7, 8]])

    assert grid.shape == (2, 3)
    assert grid[0, 1] == Discrete(4)
    assert grid[0] == MultiDiscrete([3, 4, 5])
    assert grid[0:1] == MultiDiscrete([[3, 4, 5]])
    assert grid[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]])
    assert grid[:, 0:1] == MultiDiscrete([[3], [6]])
    assert grid[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]])
    # Full slices along one or both axes are equal spaces, but new objects.
    assert grid[:] == grid
    assert grid[:] is not grid
    assert grid[:, :] == grid
    assert grid[:, :] is not grid
|
||||
|
||||
|
||||
def test_multidiscrete_subspace_reproducibility():
    """Check that sub-spaces taken with the same index produce equal samples.

    Each case also ends by comparing a sample from the full-slice
    sub-space against a sample from the parent space itself.
    """
    # 1D multi-discrete
    flat = MultiDiscrete([100, 200, 300])
    flat.seed(None)

    flat_indices = (
        np.s_[0],
        np.s_[0:1],
        np.s_[0:2],
        np.s_[:],
    )
    for index in flat_indices:
        assert sample_equal(flat[index].sample(), flat[index].sample())
    assert sample_equal(flat[:].sample(), flat.sample())

    # 2D multi-discrete
    grid = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
    grid.seed(None)

    grid_indices = (
        np.s_[0, 1],
        np.s_[0],
        np.s_[0:1],
        np.s_[0:2, :],
        np.s_[:, 0:1],
        np.s_[0:2, 0:2],
        np.s_[:],
        np.s_[:, :],
    )
    for index in grid_indices:
        assert sample_equal(grid[index].sample(), grid[index].sample())
    assert sample_equal(grid[:, :].sample(), grid.sample())
|
||||
|
Reference in New Issue
Block a user