diff --git a/gymnasium/spaces/discrete.py b/gymnasium/spaces/discrete.py
index a5d9754e6..e1ec521c9 100644
--- a/gymnasium/spaces/discrete.py
+++ b/gymnasium/spaces/discrete.py
@@ -1,14 +1,14 @@
 """Implementation of a space consisting of finitely many elements."""
 from __future__ import annotations
 
-from typing import Any, Iterable, Mapping
+from typing import Any, Iterable, Mapping, Sequence
 
 import numpy as np
 
 from gymnasium.spaces.space import MaskNDArray, Space
 
 
-class Discrete(Space[int]):
+class Discrete(Space[np.int64]):
     r"""A space consisting of finitely many elements.
 
     This class represents a finite subset of integers, more specifically a set of the form :math:`\{ a, a+1, \dots, a+n-1 \}`.
@@ -21,9 +21,9 @@ class Discrete(Space[int]):
 
     def __init__(
         self,
-        n: int,
+        n: int | np.integer[Any],
         seed: int | np.random.Generator | None = None,
-        start: int = 0,
+        start: int | np.integer[Any] = 0,
     ):
         r"""Constructor of :class:`Discrete` space.
 
@@ -34,11 +34,16 @@ class Discrete(Space[int]):
             seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
             start (int): The smallest element of this space.
         """
-        assert isinstance(n, (int, np.integer))
+        assert np.issubdtype(
+            type(n), np.integer
+        ), f"Expects `n` to be an integer, actual dtype: {type(n)}"
         assert n > 0, "n (counts) have to be positive"
-        assert isinstance(start, (int, np.integer))
-        self.n = int(n)
-        self.start = int(start)
+        assert np.issubdtype(
+            type(start), np.integer
+        ), f"Expects `start` to be an integer, actual type: {type(start)}"
+
+        self.n = np.int64(n)
+        self.start = np.int64(start)
         super().__init__((), np.int64, seed)
 
     @property
@@ -74,26 +79,26 @@ class Discrete(Space[int]):
                 np.logical_or(mask == 0, valid_action_mask)
             ), f"All values of a mask should be 0 or 1, actual values: {mask}"
             if np.any(valid_action_mask):
-                return int(
-                    self.start + self.np_random.choice(np.where(valid_action_mask)[0])
+                return self.start + self.np_random.choice(
+                    np.where(valid_action_mask)[0]
                 )
             else:
                 return self.start
 
-        return int(self.start + self.np_random.integers(self.n))
+        return self.start + self.np_random.integers(self.n)
 
     def contains(self, x: Any) -> bool:
         """Return boolean specifying if x is a valid member of this space."""
         if isinstance(x, int):
-            as_int = x
+            as_int64 = np.int64(x)
         elif isinstance(x, (np.generic, np.ndarray)) and (
             np.issubdtype(x.dtype, np.integer) and x.shape == ()
         ):
-            as_int = int(x)
+            as_int64 = np.int64(x)
         else:
             return False
 
-        return self.start <= as_int < self.start + self.n
+        return bool(self.start <= as_int64 < self.start + self.n)
 
     def __repr__(self) -> str:
         """Gives a string representation of this space."""
@@ -123,6 +128,14 @@ class Discrete(Space[int]):
         # Allow for loading of legacy states.
         # See https://github.com/openai/gym/pull/2470
         if "start" not in state:
-            state["start"] = 0
+            state["start"] = np.int64(0)
 
         super().__setstate__(state)
+
+    def to_jsonable(self, sample_n: Sequence[np.int64]) -> list[int]:
+        """Converts a list of samples to a list of ints."""
+        return [int(x) for x in sample_n]
+
+    def from_jsonable(self, sample_n: list[int]) -> list[np.int64]:
+        """Converts a list of json samples to a list of np.int64."""
+        return [np.int64(x) for x in sample_n]
diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py
index 954a96d68..aef2ea268 100644
--- a/gymnasium/spaces/utils.py
+++ b/gymnasium/spaces/utils.py
@@ -160,7 +160,7 @@ def _flatten_box_multibinary(space: Box | MultiBinary, x: NDArray[Any]) -> NDArr
 
 
 @flatten.register(Discrete)
-def _flatten_discrete(space: Discrete, x: int) -> NDArray[np.int64]:
+def _flatten_discrete(space: Discrete, x: np.int64) -> NDArray[np.int64]:
     onehot = np.zeros(space.n, dtype=space.dtype)
     onehot[x - space.start] = 1
     return onehot
@@ -268,14 +268,14 @@ def _unflatten_box_multibinary(
 
 
 @unflatten.register(Discrete)
-def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> int:
-    return int(space.start + np.nonzero(x)[0][0])
+def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
+    return space.start + np.nonzero(x)[0][0]
 
 
 @unflatten.register(MultiDiscrete)
 def _unflatten_multidiscrete(
-    space: MultiDiscrete, x: NDArray[np.int32]
-) -> NDArray[np.int32]:
+    space: MultiDiscrete, x: NDArray[np.integer[Any]]
+) -> NDArray[np.integer[Any]]:
     offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
     offsets[1:] = np.cumsum(space.nvec.flatten())
 