mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
Change discrete dtype to np.int64 (#141)
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
"""Implementation of a space consisting of finitely many elements."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Mapping
|
||||
from typing import Any, Iterable, Mapping, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.spaces.space import MaskNDArray, Space
|
||||
|
||||
|
||||
class Discrete(Space[int]):
|
||||
class Discrete(Space[np.int64]):
|
||||
r"""A space consisting of finitely many elements.
|
||||
|
||||
This class represents a finite subset of integers, more specifically a set of the form :math:`\{ a, a+1, \dots, a+n-1 \}`.
|
||||
@@ -21,9 +21,9 @@ class Discrete(Space[int]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n: int,
|
||||
n: int | np.integer[Any],
|
||||
seed: int | np.random.Generator | None = None,
|
||||
start: int = 0,
|
||||
start: int | np.integer[Any] = 0,
|
||||
):
|
||||
r"""Constructor of :class:`Discrete` space.
|
||||
|
||||
@@ -34,11 +34,16 @@ class Discrete(Space[int]):
|
||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
|
||||
start (int): The smallest element of this space.
|
||||
"""
|
||||
assert isinstance(n, (int, np.integer))
|
||||
assert np.issubdtype(
|
||||
type(n), np.integer
|
||||
), f"Expects `n` to be an integer, actual dtype: {type(n)}"
|
||||
assert n > 0, "n (counts) have to be positive"
|
||||
assert isinstance(start, (int, np.integer))
|
||||
self.n = int(n)
|
||||
self.start = int(start)
|
||||
assert np.issubdtype(
|
||||
type(start), np.integer
|
||||
), f"Expects `start` to be an integer, actual type: {type(start)}"
|
||||
|
||||
self.n = np.int64(n)
|
||||
self.start = np.int64(start)
|
||||
super().__init__((), np.int64, seed)
|
||||
|
||||
@property
|
||||
@@ -74,26 +79,26 @@ class Discrete(Space[int]):
|
||||
np.logical_or(mask == 0, valid_action_mask)
|
||||
), f"All values of a mask should be 0 or 1, actual values: {mask}"
|
||||
if np.any(valid_action_mask):
|
||||
return int(
|
||||
self.start + self.np_random.choice(np.where(valid_action_mask)[0])
|
||||
return self.start + self.np_random.choice(
|
||||
np.where(valid_action_mask)[0]
|
||||
)
|
||||
else:
|
||||
return self.start
|
||||
|
||||
return int(self.start + self.np_random.integers(self.n))
|
||||
return self.start + self.np_random.integers(self.n)
|
||||
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if isinstance(x, int):
|
||||
as_int = x
|
||||
as_int64 = np.int64(x)
|
||||
elif isinstance(x, (np.generic, np.ndarray)) and (
|
||||
np.issubdtype(x.dtype, np.integer) and x.shape == ()
|
||||
):
|
||||
as_int = int(x)
|
||||
as_int64 = np.int64(x)
|
||||
else:
|
||||
return False
|
||||
|
||||
return self.start <= as_int < self.start + self.n
|
||||
return bool(self.start <= as_int64 < self.start + self.n)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Gives a string representation of this space."""
|
||||
@@ -123,6 +128,14 @@ class Discrete(Space[int]):
|
||||
# Allow for loading of legacy states.
|
||||
# See https://github.com/openai/gym/pull/2470
|
||||
if "start" not in state:
|
||||
state["start"] = 0
|
||||
state["start"] = np.int64(0)
|
||||
|
||||
super().__setstate__(state)
|
||||
|
||||
def to_jsonable(self, sample_n: Sequence[np.int64]) -> list[int]:
|
||||
"""Converts a list of samples to a list of ints."""
|
||||
return [int(x) for x in sample_n]
|
||||
|
||||
def from_jsonable(self, sample_n: list[int]) -> list[np.int64]:
|
||||
"""Converts a list of json samples to a list of np.int64."""
|
||||
return [np.int64(x) for x in sample_n]
|
||||
|
@@ -160,7 +160,7 @@ def _flatten_box_multibinary(space: Box | MultiBinary, x: NDArray[Any]) -> NDArr
|
||||
|
||||
|
||||
@flatten.register(Discrete)
|
||||
def _flatten_discrete(space: Discrete, x: int) -> NDArray[np.int64]:
|
||||
def _flatten_discrete(space: Discrete, x: np.int64) -> NDArray[np.int64]:
|
||||
onehot = np.zeros(space.n, dtype=space.dtype)
|
||||
onehot[x - space.start] = 1
|
||||
return onehot
|
||||
@@ -268,14 +268,14 @@ def _unflatten_box_multibinary(
|
||||
|
||||
|
||||
@unflatten.register(Discrete)
|
||||
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> int:
|
||||
return int(space.start + np.nonzero(x)[0][0])
|
||||
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
|
||||
return space.start + np.nonzero(x)[0][0]
|
||||
|
||||
|
||||
@unflatten.register(MultiDiscrete)
|
||||
def _unflatten_multidiscrete(
|
||||
space: MultiDiscrete, x: NDArray[np.int32]
|
||||
) -> NDArray[np.int32]:
|
||||
space: MultiDiscrete, x: NDArray[np.integer[Any]]
|
||||
) -> NDArray[np.integer[Any]]:
|
||||
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||
|
||||
|
Reference in New Issue
Block a user