mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-20 05:52:03 +00:00
typing in gym.spaces (#2541)
* typing in spaces.Box and spaces.Discrete * adds typing to dict and tuple spaces * Typecheck all spaces * Explicit regex to include all files under space folder * Style: use native types and __future__ annotations * Allow only specific strings for Box.is_bounded args * Add typing to changes from #2517 * Remove Literal as it's not supported by py3.7 * Use more recent version of pyright * Avoid name clash for type checker * Revert "Avoid name clash for type checker" This reverts commit 1aaf3e0e0328171623a17a997b65fe734bc0afb1. * Ignore the error. It's reported as probable bug at https://github.com/microsoft/pyright/issues/2852 * rebase and add typing for `_short_repr`
This commit is contained in:
2
.github/workflows/lint_python.yml
vendored
2
.github/workflows/lint_python.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
python-version: [ "3.7"]
|
||||
fail-fast: false
|
||||
env:
|
||||
PYRIGHT_VERSION: 1.1.183
|
||||
PYRIGHT_VERSION: 1.1.204
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/setup-python@v2
|
||||
|
@@ -1,10 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple, SupportsFloat, Union, Type, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .space import Space
|
||||
from gym import logger
|
||||
|
||||
|
||||
def _short_repr(arr):
|
||||
def _short_repr(arr: np.ndarray) -> str:
|
||||
"""Create a shortened string representation of a numpy array.
|
||||
|
||||
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
|
||||
@@ -15,7 +19,7 @@ def _short_repr(arr):
|
||||
return str(arr)
|
||||
|
||||
|
||||
class Box(Space):
|
||||
class Box(Space[np.ndarray]):
|
||||
"""
|
||||
A (possibly unbounded) box in R^n. Specifically, a Box represents the
|
||||
Cartesian product of n closed intervals. Each interval has the form of one
|
||||
@@ -33,66 +37,47 @@ class Box(Space):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, low, high, shape=None, dtype=np.float32, seed=None):
|
||||
def __init__(
|
||||
self,
|
||||
low: Union[SupportsFloat, np.ndarray],
|
||||
high: Union[SupportsFloat, np.ndarray],
|
||||
shape: Optional[Sequence[int]] = None,
|
||||
dtype: Type = np.float32,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
assert dtype is not None, "dtype must be explicitly provided. "
|
||||
self.dtype = np.dtype(dtype)
|
||||
|
||||
# determine shape if it isn't provided directly
|
||||
if shape is not None:
|
||||
shape = tuple(shape)
|
||||
assert (
|
||||
np.isscalar(low) or low.shape == shape
|
||||
), "low.shape doesn't match provided shape"
|
||||
assert (
|
||||
np.isscalar(high) or high.shape == shape
|
||||
), "high.shape doesn't match provided shape"
|
||||
elif not np.isscalar(low):
|
||||
shape = low.shape
|
||||
assert (
|
||||
np.isscalar(high) or high.shape == shape
|
||||
), "high.shape doesn't match low.shape"
|
||||
shape = low.shape # type: ignore
|
||||
elif not np.isscalar(high):
|
||||
shape = high.shape
|
||||
assert (
|
||||
np.isscalar(low) or low.shape == shape
|
||||
), "low.shape doesn't match high.shape"
|
||||
shape = high.shape # type: ignore
|
||||
else:
|
||||
raise ValueError(
|
||||
"shape must be provided or inferred from the shapes of low or high"
|
||||
)
|
||||
assert isinstance(shape, tuple)
|
||||
|
||||
# handle infinite bounds and broadcast at the same time if needed
|
||||
if np.isscalar(low):
|
||||
low = get_inf(dtype, "-") if np.isinf(low) else low
|
||||
low = np.full(shape, low, dtype=dtype)
|
||||
else:
|
||||
if np.any(np.isinf(low)):
|
||||
# create new array with dtype, but maintain old one to preserve np.inf
|
||||
temp_low = low.astype(dtype)
|
||||
temp_low[np.isinf(low)] = get_inf(dtype, "-")
|
||||
low = temp_low
|
||||
low = _broadcast(low, dtype, shape, inf_sign="-") # type: ignore
|
||||
high = _broadcast(high, dtype, shape, inf_sign="+")
|
||||
|
||||
if np.isscalar(high):
|
||||
high = get_inf(dtype, "+") if np.isinf(high) else high
|
||||
high = np.full(shape, high, dtype=dtype)
|
||||
else:
|
||||
if np.any(np.isinf(high)):
|
||||
# create new array with dtype, but maintain old one to preserve np.inf
|
||||
temp_high = high.astype(dtype)
|
||||
temp_high[np.isinf(high)] = get_inf(dtype, "+")
|
||||
high = temp_high
|
||||
assert isinstance(low, np.ndarray)
|
||||
assert low.shape == shape, "low.shape doesn't match provided shape"
|
||||
assert isinstance(high, np.ndarray)
|
||||
assert high.shape == shape, "high.shape doesn't match provided shape"
|
||||
|
||||
self._shape = shape
|
||||
self.low = low
|
||||
self.high = high
|
||||
self._shape: Tuple[int, ...] = shape
|
||||
|
||||
low_precision = get_precision(self.low.dtype)
|
||||
high_precision = get_precision(self.high.dtype)
|
||||
low_precision = get_precision(low.dtype)
|
||||
high_precision = get_precision(high.dtype)
|
||||
dtype_precision = get_precision(self.dtype)
|
||||
if min(low_precision, high_precision) > dtype_precision:
|
||||
if min(low_precision, high_precision) > dtype_precision: # type: ignore
|
||||
logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
|
||||
self.low = self.low.astype(self.dtype)
|
||||
self.high = self.high.astype(self.dtype)
|
||||
self.low = low.astype(self.dtype)
|
||||
self.high = high.astype(self.dtype)
|
||||
|
||||
self.low_repr = _short_repr(self.low)
|
||||
self.high_repr = _short_repr(self.high)
|
||||
@@ -103,9 +88,14 @@ class Box(Space):
|
||||
|
||||
super().__init__(self.shape, self.dtype, seed)
|
||||
|
||||
def is_bounded(self, manner="both"):
|
||||
below = np.all(self.bounded_below)
|
||||
above = np.all(self.bounded_above)
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""Has stricter type than gym.Space - never None."""
|
||||
return self._shape
|
||||
|
||||
def is_bounded(self, manner: str = "both") -> bool:
|
||||
below = bool(np.all(self.bounded_below))
|
||||
above = bool(np.all(self.bounded_above))
|
||||
if manner == "both":
|
||||
return below and above
|
||||
elif manner == "below":
|
||||
@@ -115,7 +105,7 @@ class Box(Space):
|
||||
else:
|
||||
raise ValueError("manner is not in {'below', 'above', 'both'}")
|
||||
|
||||
def sample(self):
|
||||
def sample(self) -> np.ndarray:
|
||||
"""
|
||||
Generates a single random sample inside of the Box.
|
||||
|
||||
@@ -158,12 +148,12 @@ class Box(Space):
|
||||
|
||||
return sample.astype(self.dtype)
|
||||
|
||||
def contains(self, x):
|
||||
def contains(self, x) -> bool:
|
||||
if not isinstance(x, np.ndarray):
|
||||
logger.warn("Casting input x to numpy array.")
|
||||
x = np.asarray(x, dtype=self.dtype)
|
||||
|
||||
return (
|
||||
return bool(
|
||||
np.can_cast(x.dtype, self.dtype)
|
||||
and x.shape == self.shape
|
||||
and np.all(x >= self.low)
|
||||
@@ -173,13 +163,13 @@ class Box(Space):
|
||||
def to_jsonable(self, sample_n):
|
||||
return np.array(sample_n).tolist()
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> list[np.ndarray]:
|
||||
return [np.asarray(sample) for sample in sample_n]
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return (
|
||||
isinstance(other, Box)
|
||||
and (self.shape == other.shape)
|
||||
@@ -188,7 +178,7 @@ class Box(Space):
|
||||
)
|
||||
|
||||
|
||||
def get_inf(dtype, sign):
|
||||
def get_inf(dtype, sign: str) -> SupportsFloat:
|
||||
"""Returns an infinite that doesn't break things.
|
||||
`dtype` must be an `np.dtype`
|
||||
`bound` must be either `min` or `max`
|
||||
@@ -211,8 +201,28 @@ def get_inf(dtype, sign):
|
||||
raise ValueError(f"Unknown dtype {dtype} for infinite bounds")
|
||||
|
||||
|
||||
def get_precision(dtype):
|
||||
def get_precision(dtype) -> SupportsFloat:
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
return np.finfo(dtype).precision
|
||||
else:
|
||||
return np.inf
|
||||
|
||||
|
||||
def _broadcast(
|
||||
value: Union[SupportsFloat, np.ndarray],
|
||||
dtype,
|
||||
shape: tuple[int, ...],
|
||||
inf_sign: str,
|
||||
) -> np.ndarray:
|
||||
"""handle infinite bounds and broadcast at the same time if needed"""
|
||||
if np.isscalar(value):
|
||||
value = get_inf(dtype, inf_sign) if np.isinf(value) else value # type: ignore
|
||||
value = np.full(shape, value, dtype=dtype)
|
||||
else:
|
||||
assert isinstance(value, np.ndarray)
|
||||
if np.any(np.isinf(value)):
|
||||
# create new array with dtype, but maintain old one to preserve np.inf
|
||||
temp = value.astype(dtype)
|
||||
temp[np.isinf(value)] = get_inf(dtype, inf_sign)
|
||||
value = temp
|
||||
return value
|
||||
|
@@ -1,10 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Dict as TypingDict
|
||||
import numpy as np
|
||||
from .space import Space
|
||||
|
||||
|
||||
class Dict(Space, Mapping):
|
||||
class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
"""
|
||||
A dictionary of simpler spaces.
|
||||
|
||||
@@ -34,7 +37,12 @@ class Dict(Space, Mapping):
|
||||
})
|
||||
"""
|
||||
|
||||
def __init__(self, spaces=None, seed=None, **spaces_kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
spaces: dict[str, Space] | None = None,
|
||||
seed: dict | int | None = None,
|
||||
**spaces_kwargs: Space
|
||||
):
|
||||
assert (spaces is None) or (
|
||||
not spaces_kwargs
|
||||
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
||||
@@ -57,10 +65,10 @@ class Dict(Space, Mapping):
|
||||
space, Space
|
||||
), "Values of the dict should be instances of gym.Space"
|
||||
super().__init__(
|
||||
None, None, seed
|
||||
None, None, seed # type: ignore
|
||||
) # None for shape and dtype, since it'll require special handling
|
||||
|
||||
def seed(self, seed=None):
|
||||
def seed(self, seed: dict | int | None = None) -> list:
|
||||
seeds = []
|
||||
if isinstance(seed, dict):
|
||||
for key, seed_key in zip(self.spaces, seed):
|
||||
@@ -97,10 +105,10 @@ class Dict(Space, Mapping):
|
||||
|
||||
return seeds
|
||||
|
||||
def sample(self):
|
||||
def sample(self) -> dict:
|
||||
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
||||
|
||||
def contains(self, x):
|
||||
def contains(self, x) -> bool:
|
||||
if not isinstance(x, dict) or len(x) != len(self.spaces):
|
||||
return False
|
||||
for k, space in self.spaces.items():
|
||||
@@ -119,29 +127,30 @@ class Dict(Space, Mapping):
|
||||
def __iter__(self):
|
||||
yield from self.spaces
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.spaces)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
"Dict("
|
||||
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
def to_jsonable(self, sample_n: list) -> dict:
|
||||
# serialize as dict-repr of vectors
|
||||
return {
|
||||
key: space.to_jsonable([sample[key] for sample in sample_n])
|
||||
for key, space in self.spaces.items()
|
||||
}
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
dict_of_list = {}
|
||||
def from_jsonable(self, sample_n: dict[str, list]) -> list:
|
||||
dict_of_list: dict[str, list] = {}
|
||||
for key, space in self.spaces.items():
|
||||
dict_of_list[key] = space.from_jsonable(sample_n[key])
|
||||
ret = []
|
||||
for i, _ in enumerate(dict_of_list[key]):
|
||||
n_elements = len(next(iter(dict_of_list.values())))
|
||||
for i in range(n_elements):
|
||||
entry = {}
|
||||
for key, value in dict_of_list.items():
|
||||
entry[key] = value[i]
|
||||
|
@@ -1,8 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from .space import Space
|
||||
|
||||
|
||||
class Discrete(Space):
|
||||
class Discrete(Space[int]):
|
||||
r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`.
|
||||
|
||||
A start value can be optionally specified to shift the range
|
||||
@@ -15,33 +17,33 @@ class Discrete(Space):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n, seed=None, start=0):
|
||||
def __init__(self, n: int, seed: Optional[int] = None, start: int = 0):
|
||||
assert n > 0, "n (counts) have to be positive"
|
||||
assert isinstance(start, (int, np.integer))
|
||||
self.n = int(n)
|
||||
self.start = int(start)
|
||||
super().__init__((), np.int64, seed)
|
||||
|
||||
def sample(self) -> int:
    """Draw one element of the space and return it as a built-in ``int``.

    The value is ``self.start`` plus the offset obtained from
    ``self.np_random.randint(self.n)``.

    Returns:
        The sampled element as a plain Python ``int``.
    """
    # The RNG hands back a NumPy integer scalar; coerce it so the result
    # matches the annotated return type (and is directly JSON-serializable).
    return int(self.start + self.np_random.randint(self.n))
|
||||
|
||||
def contains(self, x):
|
||||
def contains(self, x) -> bool:
|
||||
if isinstance(x, int):
|
||||
as_int = x
|
||||
elif isinstance(x, (np.generic, np.ndarray)) and (
|
||||
x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
|
||||
):
|
||||
as_int = int(x)
|
||||
as_int = int(x) # type: ignore
|
||||
else:
|
||||
return False
|
||||
return self.start <= as_int < self.start + self.n
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
if self.start != 0:
|
||||
return "Discrete(%d, start=%d)" % (self.n, self.start)
|
||||
return "Discrete(%d)" % self.n
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return (
|
||||
isinstance(other, Discrete)
|
||||
and self.n == other.n
|
||||
|
@@ -1,9 +1,11 @@
|
||||
from collections.abc import Sequence
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union, Sequence
|
||||
import numpy as np
|
||||
from .space import Space
|
||||
|
||||
|
||||
class MultiBinary(Space):
|
||||
class MultiBinary(Space[np.ndarray]):
|
||||
"""
|
||||
An n-shape binary space.
|
||||
|
||||
@@ -27,7 +29,9 @@ class MultiBinary(Space):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n, seed=None):
|
||||
def __init__(
|
||||
self, n: Union[np.ndarray, Sequence[int], int], seed: Optional[int] = None
|
||||
):
|
||||
if isinstance(n, (Sequence, np.ndarray)):
|
||||
self.n = input_n = tuple(int(i) for i in n)
|
||||
else:
|
||||
@@ -38,24 +42,29 @@ class MultiBinary(Space):
|
||||
|
||||
super().__init__(input_n, np.int8, seed)
|
||||
|
||||
def sample(self):
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
"""Has stricter type than gym.Space - never None."""
|
||||
return self._shape # type: ignore
|
||||
|
||||
def sample(self) -> np.ndarray:
|
||||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||
|
||||
def contains(self, x) -> bool:
    """Return whether ``x`` is a valid member of this binary space.

    Args:
        x: Candidate value. Sequences are promoted to NumPy arrays before
            checking; anything that is neither a sequence nor a NumPy
            array/scalar is rejected.

    Returns:
        ``True`` when ``x`` has exactly this space's shape and every entry
        equals 0 or 1, otherwise ``False``. Always a built-in ``bool``.
    """
    if isinstance(x, Sequence):
        x = np.array(x)  # Promote list to array for contains check
    # Objects without a NumPy shape (plain ints, None, ...) cannot be members;
    # answer False rather than failing on the missing ``.shape`` attribute.
    if not isinstance(x, (np.ndarray, np.generic)):
        return False
    if self.shape != x.shape:
        return False
    # ``.all()`` yields ``np.bool_``; convert so the annotation holds.
    return bool(((x == 0) | (x == 1)).all())
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
def to_jsonable(self, sample_n) -> list:
|
||||
return np.array(sample_n).tolist()
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
def from_jsonable(self, sample_n) -> list:
|
||||
return [np.asarray(sample) for sample in sample_n]
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"MultiBinary({self.n})"
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return isinstance(other, MultiBinary) and self.n == other.n
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import numpy as np
|
||||
from gym import logger
|
||||
@@ -5,7 +7,7 @@ from .space import Space
|
||||
from .discrete import Discrete
|
||||
|
||||
|
||||
class MultiDiscrete(Space):
|
||||
class MultiDiscrete(Space[np.ndarray]):
|
||||
"""
|
||||
- The multi-discrete action space consists of a series of discrete action spaces with different number of actions in each
|
||||
- It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space
|
||||
@@ -26,7 +28,7 @@ class MultiDiscrete(Space):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, nvec, dtype=np.int64, seed=None):
|
||||
def __init__(self, nvec: list[int], dtype=np.int64, seed=None):
|
||||
"""
|
||||
nvec: vector of counts of each categorical variable
|
||||
"""
|
||||
@@ -35,15 +37,20 @@ class MultiDiscrete(Space):
|
||||
|
||||
super().__init__(self.nvec.shape, dtype, seed)
|
||||
|
||||
def sample(self):
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
"""Has stricter type than gym.Space - never None."""
|
||||
return self._shape # type: ignore
|
||||
|
||||
def sample(self) -> np.ndarray:
|
||||
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
|
||||
|
||||
def contains(self, x) -> bool:
    """Return whether ``x`` is a valid member of this multi-discrete space.

    Args:
        x: Candidate value. Sequences are promoted to NumPy arrays before
            checking; anything that is neither a sequence nor a NumPy
            array/scalar is rejected.

    Returns:
        ``True`` when ``x`` has this space's shape and every entry satisfies
        ``0 <= x < self.nvec`` element-wise, otherwise ``False``.
    """
    if isinstance(x, Sequence):
        x = np.array(x)  # Promote list to array for contains check
    # Objects without a NumPy shape (plain ints, None, ...) cannot be members;
    # answer False rather than failing on the missing ``.shape`` attribute.
    if not isinstance(x, (np.ndarray, np.generic)):
        return False
    # if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
    # is within correct bounds for space dtype (even though x does not have to be unsigned)
    return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all())
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
return [sample.tolist() for sample in sample_n]
|
||||
|
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TypeVar,
|
||||
Generic,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
@@ -32,8 +33,13 @@ class Space(Generic[T_cov]):
|
||||
not handle custom spaces properly. Use custom spaces with care.
|
||||
"""
|
||||
|
||||
def __init__(self, shape: Optional[Sequence[int]] = None, dtype=None, seed=None):
|
||||
import numpy as np # takes about 300-400ms to import, so we load lazily
|
||||
def __init__(
|
||||
self,
|
||||
shape: Optional[Sequence[int]] = None,
|
||||
dtype: Optional[Type | str] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
import numpy as np # noqa ## takes about 300-400ms to import, so we load lazily
|
||||
|
||||
self._shape = None if shape is None else tuple(shape)
|
||||
self.dtype = None if dtype is None else np.dtype(dtype)
|
||||
@@ -42,7 +48,7 @@ class Space(Generic[T_cov]):
|
||||
self.seed(seed)
|
||||
|
||||
@property
|
||||
def np_random(self) -> np.random.RandomState:
|
||||
def np_random(self) -> seeding.RandomNumberGenerator:
|
||||
"""Lazily seed the rng since this is expensive and only needed if
|
||||
sampling from this space.
|
||||
"""
|
||||
@@ -52,7 +58,7 @@ class Space(Generic[T_cov]):
|
||||
return self._np_random # type: ignore ## self.seed() call guarantees right type.
|
||||
|
||||
@property
|
||||
def shape(self) -> Optional[Tuple[int, ...]]:
|
||||
def shape(self) -> Optional[tuple[int, ...]]:
|
||||
"""Return the shape of the space as an immutable property"""
|
||||
return self._shape
|
||||
|
||||
@@ -61,7 +67,7 @@ class Space(Generic[T_cov]):
|
||||
uniform or non-uniform sampling based on boundedness of space."""
|
||||
raise NotImplementedError
|
||||
|
||||
def seed(self, seed: Optional[int] = None):
|
||||
def seed(self, seed: Optional[int] = None) -> list:
|
||||
"""Seed the PRNG of this space."""
|
||||
self._np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
@@ -76,7 +82,7 @@ class Space(Generic[T_cov]):
|
||||
def __contains__(self, x) -> bool:
|
||||
return self.contains(x)
|
||||
|
||||
def __setstate__(self, state: Union[Iterable, Mapping]):
|
||||
def __setstate__(self, state: Iterable | Mapping):
|
||||
# Don't mutate the original state
|
||||
state = dict(state)
|
||||
|
||||
@@ -95,12 +101,12 @@ class Space(Generic[T_cov]):
|
||||
# Update our state
|
||||
self.__dict__.update(state)
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
def to_jsonable(self, sample_n: Sequence[T_cov]) -> list:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
# By default, assume identity is JSONable
|
||||
return sample_n
|
||||
return list(sample_n)
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
def from_jsonable(self, sample_n: list) -> list[T_cov]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
# By default, assume identity is JSONable
|
||||
return sample_n
|
||||
|
@@ -1,8 +1,9 @@
|
||||
from typing import Iterable, List, Optional, Union
|
||||
import numpy as np
|
||||
from .space import Space
|
||||
|
||||
|
||||
class Tuple(Space):
|
||||
class Tuple(Space[tuple]):
|
||||
"""
|
||||
A tuple (i.e., product) of simpler spaces
|
||||
|
||||
@@ -10,16 +11,18 @@ class Tuple(Space):
|
||||
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
|
||||
"""
|
||||
|
||||
def __init__(self, spaces, seed=None):
|
||||
def __init__(
|
||||
self, spaces: Iterable[Space], seed: Optional[Union[int, List[int]]] = None
|
||||
):
|
||||
spaces = tuple(spaces)
|
||||
self.spaces = spaces
|
||||
for space in spaces:
|
||||
assert isinstance(
|
||||
space, Space
|
||||
), "Elements of the tuple must be instances of gym.Space"
|
||||
super().__init__(None, None, seed)
|
||||
super().__init__(None, None, seed) # type: ignore
|
||||
|
||||
def seed(self, seed=None):
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
|
||||
seeds = []
|
||||
|
||||
if isinstance(seed, list):
|
||||
@@ -50,10 +53,10 @@ class Tuple(Space):
|
||||
|
||||
return seeds
|
||||
|
||||
def sample(self):
|
||||
def sample(self) -> tuple:
|
||||
return tuple(space.sample() for space in self.spaces)
|
||||
|
||||
def contains(self, x):
|
||||
def contains(self, x) -> bool:
|
||||
if isinstance(x, (list, np.ndarray)):
|
||||
x = tuple(x) # Promote list and ndarray to tuple for contains check
|
||||
return (
|
||||
@@ -62,17 +65,17 @@ class Tuple(Space):
|
||||
and all(space.contains(part) for (space, part) in zip(self.spaces, x))
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
def to_jsonable(self, sample_n) -> list:
|
||||
# serialize as list-repr of tuple of vectors
|
||||
return [
|
||||
space.to_jsonable([sample[i] for sample in sample_n])
|
||||
for i, space in enumerate(self.spaces)
|
||||
]
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
def from_jsonable(self, sample_n) -> list:
|
||||
return [
|
||||
sample
|
||||
for sample in zip(
|
||||
@@ -83,11 +86,11 @@ class Tuple(Space):
|
||||
)
|
||||
]
|
||||
|
||||
def __getitem__(self, index):
|
||||
def __getitem__(self, index: int) -> Space:
|
||||
return self.spaces[index]
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.spaces)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return isinstance(other, Tuple) and self.spaces == other.spaces
|
||||
|
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from functools import singledispatch, reduce
|
||||
from typing import TypeVar, Union
|
||||
import numpy as np
|
||||
import operator as op
|
||||
|
||||
@@ -9,10 +12,11 @@ from gym.spaces import MultiDiscrete
|
||||
from gym.spaces import MultiBinary
|
||||
from gym.spaces import Tuple
|
||||
from gym.spaces import Dict
|
||||
from gym.spaces import Space
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatdim(space):
|
||||
def flatdim(space: Space) -> int:
|
||||
"""Return the number of dimensions a flattened equivalent of this space
|
||||
would have.
|
||||
|
||||
@@ -24,32 +28,35 @@ def flatdim(space):
|
||||
|
||||
@flatdim.register(Box)
|
||||
@flatdim.register(MultiBinary)
|
||||
def _flatdim_box_multibinary(space):
|
||||
def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int:
|
||||
return reduce(op.mul, space.shape, 1)
|
||||
|
||||
|
||||
@flatdim.register(Discrete)
|
||||
def _flatdim_discrete(space):
|
||||
def _flatdim_discrete(space: Discrete) -> int:
|
||||
return int(space.n)
|
||||
|
||||
|
||||
@flatdim.register(MultiDiscrete)
|
||||
def _flatdim_multidiscrete(space):
|
||||
def _flatdim_multidiscrete(space: MultiDiscrete) -> int:
|
||||
return int(np.sum(space.nvec))
|
||||
|
||||
|
||||
@flatdim.register(Tuple)
|
||||
def _flatdim_tuple(space):
|
||||
def _flatdim_tuple(space: Tuple) -> int:
|
||||
return sum(flatdim(s) for s in space.spaces)
|
||||
|
||||
|
||||
@flatdim.register(Dict)
|
||||
def _flatdim_dict(space):
|
||||
def _flatdim_dict(space: Dict) -> int:
|
||||
return sum(flatdim(s) for s in space.spaces.values())
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten(space, x):
|
||||
def flatten(space: Space[T], x: T) -> np.ndarray:
|
||||
"""Flatten a data point from a space.
|
||||
|
||||
This is useful when e.g. points from spaces must be passed to a neural
|
||||
@@ -64,19 +71,19 @@ def flatten(space, x):
|
||||
|
||||
@flatten.register(Box)
|
||||
@flatten.register(MultiBinary)
|
||||
def _flatten_box_multibinary(space, x):
|
||||
def _flatten_box_multibinary(space, x) -> np.ndarray:
|
||||
return np.asarray(x, dtype=space.dtype).flatten()
|
||||
|
||||
|
||||
@flatten.register(Discrete)
|
||||
def _flatten_discrete(space, x) -> np.ndarray:
    """One-hot encode a sample ``x`` drawn from a ``Discrete`` space.

    Args:
        space: The ``Discrete`` space; its ``n``, ``start`` and ``dtype``
            attributes are read.
        x: An element of ``space``, i.e. ``start <= x < start + n``.

    Returns:
        A length-``n`` array of ``space.dtype`` that is 1 at position
        ``x - space.start`` and 0 elsewhere.
    """
    onehot = np.zeros(space.n, dtype=space.dtype)
    # Index relative to the first element of the space: using ``x`` itself is
    # only correct when ``start == 0`` and otherwise marks the wrong slot or
    # runs past the end of the vector.
    onehot[x - space.start] = 1
    return onehot
|
||||
|
||||
|
||||
@flatten.register(MultiDiscrete)
|
||||
def _flatten_multidiscrete(space, x):
|
||||
def _flatten_multidiscrete(space, x) -> np.ndarray:
|
||||
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||
|
||||
@@ -86,17 +93,17 @@ def _flatten_multidiscrete(space, x):
|
||||
|
||||
|
||||
@flatten.register(Tuple)
|
||||
def _flatten_tuple(space, x):
|
||||
def _flatten_tuple(space, x) -> np.ndarray:
|
||||
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
|
||||
|
||||
|
||||
@flatten.register(Dict)
|
||||
def _flatten_dict(space, x):
|
||||
def _flatten_dict(space, x) -> np.ndarray:
|
||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||
|
||||
|
||||
@singledispatch
|
||||
def unflatten(space, x):
|
||||
def unflatten(space: Space[T], x: np.ndarray) -> T:
|
||||
"""Unflatten a data point from a space.
|
||||
|
||||
This reverses the transformation applied by ``flatten()``. You must ensure
|
||||
@@ -111,17 +118,17 @@ def unflatten(space, x):
|
||||
|
||||
@unflatten.register(Box)
|
||||
@unflatten.register(MultiBinary)
|
||||
def _unflatten_box_multibinary(space, x):
|
||||
def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.ndarray:
|
||||
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
||||
|
||||
|
||||
@unflatten.register(Discrete)
|
||||
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
    """Decode a one-hot vector back into an element of a ``Discrete`` space.

    Args:
        space: The ``Discrete`` space; its ``start`` attribute is read.
        x: One-hot vector; the position of its first non-zero entry is used.

    Returns:
        ``space.start`` plus the index of the first non-zero entry of ``x``,
        as a built-in ``int``.
    """
    # Position within the one-hot vector, counted from 0.
    offset = np.nonzero(x)[0][0]
    # Shift by ``start`` so the result is a member of the space, not merely an
    # index into it.
    return int(space.start + offset)
|
||||
|
||||
|
||||
@unflatten.register(MultiDiscrete)
|
||||
def _unflatten_multidiscrete(space, x):
|
||||
def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
|
||||
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||
|
||||
@@ -130,7 +137,7 @@ def _unflatten_multidiscrete(space, x):
|
||||
|
||||
|
||||
@unflatten.register(Tuple)
|
||||
def _unflatten_tuple(space, x):
|
||||
def _unflatten_tuple(space: Tuple, x: np.ndarray) -> tuple:
|
||||
dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_)
|
||||
list_flattened = np.split(x, np.cumsum(dims[:-1]))
|
||||
return tuple(
|
||||
@@ -139,7 +146,7 @@ def _unflatten_tuple(space, x):
|
||||
|
||||
|
||||
@unflatten.register(Dict)
|
||||
def _unflatten_dict(space, x):
|
||||
def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
|
||||
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
|
||||
list_flattened = np.split(x, np.cumsum(dims[:-1]))
|
||||
return OrderedDict(
|
||||
@@ -151,7 +158,7 @@ def _unflatten_dict(space, x):
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten_space(space):
|
||||
def flatten_space(space: Space) -> Box:
|
||||
"""Flatten a space into a single ``Box``.
|
||||
|
||||
This is equivalent to ``flatten()``, but operates on the space itself. The
|
||||
@@ -193,32 +200,32 @@ def flatten_space(space):
|
||||
|
||||
|
||||
@flatten_space.register(Box)
|
||||
def _flatten_space_box(space):
|
||||
def _flatten_space_box(space: Box) -> Box:
|
||||
return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype)
|
||||
|
||||
|
||||
@flatten_space.register(Discrete)
|
||||
@flatten_space.register(MultiBinary)
|
||||
@flatten_space.register(MultiDiscrete)
|
||||
def _flatten_space_binary(space):
|
||||
def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box:
|
||||
return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype)
|
||||
|
||||
|
||||
@flatten_space.register(Tuple)
|
||||
def _flatten_space_tuple(space):
|
||||
space = [flatten_space(s) for s in space.spaces]
|
||||
def _flatten_space_tuple(space: Tuple) -> Box:
|
||||
space_list = [flatten_space(s) for s in space.spaces]
|
||||
return Box(
|
||||
low=np.concatenate([s.low for s in space]),
|
||||
high=np.concatenate([s.high for s in space]),
|
||||
dtype=np.result_type(*[s.dtype for s in space]),
|
||||
low=np.concatenate([s.low for s in space_list]),
|
||||
high=np.concatenate([s.high for s in space_list]),
|
||||
dtype=np.result_type(*[s.dtype for s in space_list]),
|
||||
)
|
||||
|
||||
|
||||
@flatten_space.register(Dict)
|
||||
def _flatten_space_dict(space):
|
||||
space = [flatten_space(s) for s in space.spaces.values()]
|
||||
def _flatten_space_dict(space: Dict) -> Box:
|
||||
space_list = [flatten_space(s) for s in space.spaces.values()]
|
||||
return Box(
|
||||
low=np.concatenate([s.low for s in space]),
|
||||
high=np.concatenate([s.high for s in space]),
|
||||
dtype=np.result_type(*[s.dtype for s in space]),
|
||||
low=np.concatenate([s.low for s in space_list]),
|
||||
high=np.concatenate([s.high for s in space_list]),
|
||||
dtype=np.result_type(*[s.dtype for s in space_list]),
|
||||
)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import hashlib
|
||||
from typing import Optional, List, Union
|
||||
from typing import Optional, List, Tuple, Union, Any
|
||||
import os
|
||||
import struct
|
||||
|
||||
@@ -10,7 +10,7 @@ from gym import error
|
||||
from gym.logger import deprecation
|
||||
|
||||
|
||||
def np_random(seed: Optional[int] = None) -> tuple:
|
||||
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
|
||||
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
||||
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
|
||||
|
||||
|
@@ -4,7 +4,7 @@ include = [
|
||||
"gym/version.py",
|
||||
"gym/logger.py",
|
||||
"gym/envs/registration.py",
|
||||
"gym/spaces/space.py",
|
||||
"gym/spaces/**.py",
|
||||
"gym/core.py",
|
||||
"gym/utils/seeding.py"
|
||||
]
|
||||
|
Reference in New Issue
Block a user