mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Make MultiDiscrete
a Tuple
-like space (#2364)
* Make MultiDiscrete a Tuple-like space * Update test cases for MultiDiscrete
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import numpy as np
|
||||
from gym.logger import warn
|
||||
from .space import Space
|
||||
from .discrete import Discrete
|
||||
|
||||
|
||||
class MultiDiscrete(Space):
|
||||
@@ -24,7 +26,6 @@ class MultiDiscrete(Space):
|
||||
"""
|
||||
|
||||
def __init__(self, nvec, dtype=np.int64):
|
||||
|
||||
"""
|
||||
nvec: vector of counts of each categorical variable
|
||||
"""
|
||||
@@ -54,5 +55,19 @@ class MultiDiscrete(Space):
|
||||
def __repr__(self):
|
||||
return "MultiDiscrete({})".format(self.nvec)
|
||||
|
||||
def __getitem__(self, index):
|
||||
nvec = self.nvec[index]
|
||||
if nvec.ndim == 0:
|
||||
subspace = Discrete(nvec)
|
||||
else:
|
||||
subspace = MultiDiscrete(nvec, self.dtype)
|
||||
subspace.np_random.set_state(self.np_random.get_state()) # for reproducibility
|
||||
return subspace
|
||||
|
||||
def __len__(self):
|
||||
if self.nvec.ndim >= 2:
|
||||
warn("Get length of a multi-dimensional MultiDiscrete space.")
|
||||
return len(self.nvec)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
||||
|
Reference in New Issue
Block a user