mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 15:04:20 +00:00
Check that MultiDiscrete.dtype
is not None (#1196)
This commit is contained in:
@@ -59,6 +59,19 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
|
||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||
start: Optionally, the starting value the element of each class will take (defaults to 0).
|
||||
"""
|
||||
# determine dtype
|
||||
if dtype is None:
|
||||
raise ValueError(
|
||||
"MultiDiscrete dtype must be explicitly provided, cannot be None."
|
||||
)
|
||||
self.dtype = np.dtype(dtype)
|
||||
|
||||
# * check that dtype is an accepted dtype
|
||||
if not (np.issubdtype(self.dtype, np.integer)):
|
||||
raise ValueError(
|
||||
f"Invalid MultiDiscrete dtype ({self.dtype}), must be an integer dtype"
|
||||
)
|
||||
|
||||
self.nvec = np.array(nvec, dtype=dtype, copy=True)
|
||||
if start is not None:
|
||||
self.start = np.array(start, dtype=dtype, copy=True)
|
||||
@@ -70,7 +83,7 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
|
||||
), "start and nvec (counts) should have the same shape"
|
||||
assert (self.nvec > 0).all(), "nvec (counts) have to be positive"
|
||||
|
||||
super().__init__(self.nvec.shape, dtype, seed)
|
||||
super().__init__(self.nvec.shape, self.dtype, seed)
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
|
Reference in New Issue
Block a user