Change discrete dtype to np.int64 (#141)

This commit is contained in:
Mark Towers
2022-11-29 14:57:46 +00:00
committed by GitHub
parent df811e7d54
commit 6f139cdec5
2 changed files with 33 additions and 20 deletions

View File

@@ -1,14 +1,14 @@
"""Implementation of a space consisting of finitely many elements."""
from __future__ import annotations
from typing import Any, Iterable, Mapping
from typing import Any, Iterable, Mapping, Sequence
import numpy as np
from gymnasium.spaces.space import MaskNDArray, Space
class Discrete(Space[int]):
class Discrete(Space[np.int64]):
r"""A space consisting of finitely many elements.
This class represents a finite subset of integers, more specifically a set of the form :math:`\{ a, a+1, \dots, a+n-1 \}`.
@@ -21,9 +21,9 @@ class Discrete(Space[int]):
def __init__(
self,
n: int,
n: int | np.integer[Any],
seed: int | np.random.Generator | None = None,
start: int = 0,
start: int | np.integer[Any] = 0,
):
r"""Constructor of :class:`Discrete` space.
@@ -34,11 +34,16 @@ class Discrete(Space[int]):
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
start (int): The smallest element of this space.
"""
assert isinstance(n, (int, np.integer))
assert np.issubdtype(
type(n), np.integer
), f"Expects `n` to be an integer, actual dtype: {type(n)}"
assert n > 0, "n (counts) have to be positive"
assert isinstance(start, (int, np.integer))
self.n = int(n)
self.start = int(start)
assert np.issubdtype(
type(start), np.integer
), f"Expects `start` to be an integer, actual type: {type(start)}"
self.n = np.int64(n)
self.start = np.int64(start)
super().__init__((), np.int64, seed)
@property
@@ -74,26 +79,26 @@ class Discrete(Space[int]):
np.logical_or(mask == 0, valid_action_mask)
), f"All values of a mask should be 0 or 1, actual values: {mask}"
if np.any(valid_action_mask):
return int(
self.start + self.np_random.choice(np.where(valid_action_mask)[0])
return self.start + self.np_random.choice(
np.where(valid_action_mask)[0]
)
else:
return self.start
return int(self.start + self.np_random.integers(self.n))
return self.start + self.np_random.integers(self.n)
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, int):
as_int = x
as_int64 = np.int64(x)
elif isinstance(x, (np.generic, np.ndarray)) and (
np.issubdtype(x.dtype, np.integer) and x.shape == ()
):
as_int = int(x)
as_int64 = np.int64(x)
else:
return False
return self.start <= as_int < self.start + self.n
return bool(self.start <= as_int64 < self.start + self.n)
def __repr__(self) -> str:
"""Gives a string representation of this space."""
@@ -123,6 +128,14 @@ class Discrete(Space[int]):
# Allow for loading of legacy states.
# See https://github.com/openai/gym/pull/2470
if "start" not in state:
state["start"] = 0
state["start"] = np.int64(0)
super().__setstate__(state)
def to_jsonable(self, sample_n: Sequence[np.int64]) -> list[int]:
"""Converts a list of samples to a list of ints."""
return [int(x) for x in sample_n]
def from_jsonable(self, sample_n: list[int]) -> list[np.int64]:
"""Converts a list of json samples to a list of np.int64."""
return [np.int64(x) for x in sample_n]

View File

@@ -160,7 +160,7 @@ def _flatten_box_multibinary(space: Box | MultiBinary, x: NDArray[Any]) -> NDArr
@flatten.register(Discrete)
def _flatten_discrete(space: Discrete, x: int) -> NDArray[np.int64]:
def _flatten_discrete(space: Discrete, x: np.int64) -> NDArray[np.int64]:
onehot = np.zeros(space.n, dtype=space.dtype)
onehot[x - space.start] = 1
return onehot
@@ -268,14 +268,14 @@ def _unflatten_box_multibinary(
@unflatten.register(Discrete)
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> int:
return int(space.start + np.nonzero(x)[0][0])
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
return space.start + np.nonzero(x)[0][0]
@unflatten.register(MultiDiscrete)
def _unflatten_multidiscrete(
space: MultiDiscrete, x: NDArray[np.int32]
) -> NDArray[np.int32]:
space: MultiDiscrete, x: NDArray[np.integer[Any]]
) -> NDArray[np.integer[Any]]:
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
offsets[1:] = np.cumsum(space.nvec.flatten())