mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 23:12:46 +00:00
Check that MultiDiscrete.dtype
is not None (#1196)
This commit is contained in:
@@ -59,6 +59,19 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
|
|||||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||||
start: Optionally, the starting value the element of each class will take (defaults to 0).
|
start: Optionally, the starting value the element of each class will take (defaults to 0).
|
||||||
"""
|
"""
|
||||||
|
# determine dtype
|
||||||
|
if dtype is None:
|
||||||
|
raise ValueError(
|
||||||
|
"MultiDiscrete dtype must be explicitly provided, cannot be None."
|
||||||
|
)
|
||||||
|
self.dtype = np.dtype(dtype)
|
||||||
|
|
||||||
|
# * check that dtype is an accepted dtype
|
||||||
|
if not (np.issubdtype(self.dtype, np.integer)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid MultiDiscrete dtype ({self.dtype}), must be an integer dtype"
|
||||||
|
)
|
||||||
|
|
||||||
self.nvec = np.array(nvec, dtype=dtype, copy=True)
|
self.nvec = np.array(nvec, dtype=dtype, copy=True)
|
||||||
if start is not None:
|
if start is not None:
|
||||||
self.start = np.array(start, dtype=dtype, copy=True)
|
self.start = np.array(start, dtype=dtype, copy=True)
|
||||||
@@ -70,7 +83,7 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
|
|||||||
), "start and nvec (counts) should have the same shape"
|
), "start and nvec (counts) should have the same shape"
|
||||||
assert (self.nvec > 0).all(), "nvec (counts) have to be positive"
|
assert (self.nvec > 0).all(), "nvec (counts) have to be positive"
|
||||||
|
|
||||||
super().__init__(self.nvec.shape, dtype, seed)
|
super().__init__(self.nvec.shape, self.dtype, seed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> tuple[int, ...]:
|
def shape(self) -> tuple[int, ...]:
|
||||||
|
Reference in New Issue
Block a user