Check that MultiDiscrete.dtype is not None (#1196)

This commit is contained in:
JDRanpariya
2024-10-06 16:39:58 +02:00
committed by GitHub
parent 175202f9b3
commit 196625488f

View File

@@ -59,6 +59,19 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
start: Optionally, the starting value the element of each class will take (defaults to 0).
"""
# determine dtype
if dtype is None:
raise ValueError(
"MultiDiscrete dtype must be explicitly provided, cannot be None."
)
self.dtype = np.dtype(dtype)
# * check that dtype is an accepted dtype
if not (np.issubdtype(self.dtype, np.integer)):
raise ValueError(
f"Invalid MultiDiscrete dtype ({self.dtype}), must be an integer dtype"
)
self.nvec = np.array(nvec, dtype=dtype, copy=True)
if start is not None:
self.start = np.array(start, dtype=dtype, copy=True)
@@ -70,7 +83,7 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
), "start and nvec (counts) should have the same shape"
assert (self.nvec > 0).all(), "nvec (counts) have to be positive"
super().__init__(self.nvec.shape, dtype, seed)
super().__init__(self.nvec.shape, self.dtype, seed)
@property
def shape(self) -> tuple[int, ...]: