Make MultiDiscrete a Tuple-like space (#2364)

* Make MultiDiscrete a Tuple-like space

* Update test cases for MultiDiscrete
This commit is contained in:
Xuehai Pan
2021-09-12 00:54:52 +08:00
committed by GitHub
parent 8da6224b72
commit a8f551ed44
2 changed files with 72 additions and 5 deletions

View File

@@ -1,5 +1,7 @@
import numpy as np
from gym.logger import warn
from .space import Space
from .discrete import Discrete
class MultiDiscrete(Space):
@@ -24,7 +26,6 @@ class MultiDiscrete(Space):
"""
def __init__(self, nvec, dtype=np.int64):
"""
nvec: vector of counts of each categorical variable
"""
@@ -54,5 +55,19 @@ class MultiDiscrete(Space):
def __repr__(self):
return "MultiDiscrete({})".format(self.nvec)
def __getitem__(self, index):
nvec = self.nvec[index]
if nvec.ndim == 0:
subspace = Discrete(nvec)
else:
subspace = MultiDiscrete(nvec, self.dtype)
subspace.np_random.set_state(self.np_random.get_state()) # for reproducibility
return subspace
def __len__(self):
if self.nvec.ndim >= 2:
warn("Get length of a multi-dimensional MultiDiscrete space.")
return len(self.nvec)
def __eq__(self, other):
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)