typing in gym.spaces (#2541)

* typing in spaces.Box and spaces.Discrete

* adds typing to dict and tuple spaces

* Typecheck all spaces

* Explicit regex to include all files under space folder

* Style: use native types and __future__ annotations

* Allow only specific strings for Box.is_bounded args

* Add typing to changes from #2517

* Remove Literal as it's not supported by py3.7

* Use more recent version of pyright

* Avoid name clash for type checker

* Revert "Avoid name clash for type checker"

This reverts commit 1aaf3e0e0328171623a17a997b65fe734bc0afb1.

* Ignore the error. It's reported as probable bug at https://github.com/microsoft/pyright/issues/2852

* rebase and add typing for `_short_repr`
This commit is contained in:
Ilya Kamen
2022-01-24 23:22:11 +01:00
committed by GitHub
parent fcbff7de12
commit ad79b0ad0f
11 changed files with 199 additions and 146 deletions

View File

@@ -22,7 +22,7 @@ jobs:
python-version: [ "3.7"]
fail-fast: false
env:
PYRIGHT_VERSION: 1.1.183
PYRIGHT_VERSION: 1.1.204
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2

View File

@@ -1,10 +1,14 @@
from __future__ import annotations
from typing import Tuple, SupportsFloat, Union, Type, Optional, Sequence
import numpy as np
from .space import Space
from gym import logger
def _short_repr(arr):
def _short_repr(arr: np.ndarray) -> str:
"""Create a shortened string representation of a numpy array.
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
@@ -15,7 +19,7 @@ def _short_repr(arr):
return str(arr)
class Box(Space):
class Box(Space[np.ndarray]):
"""
A (possibly unbounded) box in R^n. Specifically, a Box represents the
Cartesian product of n closed intervals. Each interval has the form of one
@@ -33,66 +37,47 @@ class Box(Space):
"""
def __init__(self, low, high, shape=None, dtype=np.float32, seed=None):
def __init__(
self,
low: Union[SupportsFloat, np.ndarray],
high: Union[SupportsFloat, np.ndarray],
shape: Optional[Sequence[int]] = None,
dtype: Type = np.float32,
seed: Optional[int] = None,
):
assert dtype is not None, "dtype must be explicitly provided. "
self.dtype = np.dtype(dtype)
# determine shape if it isn't provided directly
if shape is not None:
shape = tuple(shape)
assert (
np.isscalar(low) or low.shape == shape
), "low.shape doesn't match provided shape"
assert (
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match provided shape"
elif not np.isscalar(low):
shape = low.shape
assert (
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match low.shape"
shape = low.shape # type: ignore
elif not np.isscalar(high):
shape = high.shape
assert (
np.isscalar(low) or low.shape == shape
), "low.shape doesn't match high.shape"
shape = high.shape # type: ignore
else:
raise ValueError(
"shape must be provided or inferred from the shapes of low or high"
)
assert isinstance(shape, tuple)
# handle infinite bounds and broadcast at the same time if needed
if np.isscalar(low):
low = get_inf(dtype, "-") if np.isinf(low) else low
low = np.full(shape, low, dtype=dtype)
else:
if np.any(np.isinf(low)):
# create new array with dtype, but maintain old one to preserve np.inf
temp_low = low.astype(dtype)
temp_low[np.isinf(low)] = get_inf(dtype, "-")
low = temp_low
low = _broadcast(low, dtype, shape, inf_sign="-") # type: ignore
high = _broadcast(high, dtype, shape, inf_sign="+")
if np.isscalar(high):
high = get_inf(dtype, "+") if np.isinf(high) else high
high = np.full(shape, high, dtype=dtype)
else:
if np.any(np.isinf(high)):
# create new array with dtype, but maintain old one to preserve np.inf
temp_high = high.astype(dtype)
temp_high[np.isinf(high)] = get_inf(dtype, "+")
high = temp_high
assert isinstance(low, np.ndarray)
assert low.shape == shape, "low.shape doesn't match provided shape"
assert isinstance(high, np.ndarray)
assert high.shape == shape, "high.shape doesn't match provided shape"
self._shape = shape
self.low = low
self.high = high
self._shape: Tuple[int, ...] = shape
low_precision = get_precision(self.low.dtype)
high_precision = get_precision(self.high.dtype)
low_precision = get_precision(low.dtype)
high_precision = get_precision(high.dtype)
dtype_precision = get_precision(self.dtype)
if min(low_precision, high_precision) > dtype_precision:
if min(low_precision, high_precision) > dtype_precision: # type: ignore
logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
self.low = self.low.astype(self.dtype)
self.high = self.high.astype(self.dtype)
self.low = low.astype(self.dtype)
self.high = high.astype(self.dtype)
self.low_repr = _short_repr(self.low)
self.high_repr = _short_repr(self.high)
@@ -103,9 +88,14 @@ class Box(Space):
super().__init__(self.shape, self.dtype, seed)
def is_bounded(self, manner="both"):
below = np.all(self.bounded_below)
above = np.all(self.bounded_above)
@property
def shape(self) -> Tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape
def is_bounded(self, manner: str = "both") -> bool:
below = bool(np.all(self.bounded_below))
above = bool(np.all(self.bounded_above))
if manner == "both":
return below and above
elif manner == "below":
@@ -115,7 +105,7 @@ class Box(Space):
else:
raise ValueError("manner is not in {'below', 'above', 'both'}")
def sample(self):
def sample(self) -> np.ndarray:
"""
Generates a single random sample inside of the Box.
@@ -158,12 +148,12 @@ class Box(Space):
return sample.astype(self.dtype)
def contains(self, x):
def contains(self, x) -> bool:
if not isinstance(x, np.ndarray):
logger.warn("Casting input x to numpy array.")
x = np.asarray(x, dtype=self.dtype)
return (
return bool(
np.can_cast(x.dtype, self.dtype)
and x.shape == self.shape
and np.all(x >= self.low)
@@ -173,13 +163,13 @@ class Box(Space):
def to_jsonable(self, sample_n):
return np.array(sample_n).tolist()
def from_jsonable(self, sample_n):
def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> list[np.ndarray]:
return [np.asarray(sample) for sample in sample_n]
def __repr__(self):
def __repr__(self) -> str:
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
def __eq__(self, other):
def __eq__(self, other) -> bool:
return (
isinstance(other, Box)
and (self.shape == other.shape)
@@ -188,7 +178,7 @@ class Box(Space):
)
def get_inf(dtype, sign):
def get_inf(dtype, sign: str) -> SupportsFloat:
"""Returns an infinite that doesn't break things.
`dtype` must be an `np.dtype`
`bound` must be either `min` or `max`
@@ -211,8 +201,28 @@ def get_inf(dtype, sign):
raise ValueError(f"Unknown dtype {dtype} for infinite bounds")
def get_precision(dtype):
def get_precision(dtype) -> SupportsFloat:
if np.issubdtype(dtype, np.floating):
return np.finfo(dtype).precision
else:
return np.inf
def _broadcast(
value: Union[SupportsFloat, np.ndarray],
dtype,
shape: tuple[int, ...],
inf_sign: str,
) -> np.ndarray:
"""handle infinite bounds and broadcast at the same time if needed"""
if np.isscalar(value):
value = get_inf(dtype, inf_sign) if np.isinf(value) else value # type: ignore
value = np.full(shape, value, dtype=dtype)
else:
assert isinstance(value, np.ndarray)
if np.any(np.isinf(value)):
# create new array with dtype, but maintain old one to preserve np.inf
temp = value.astype(dtype)
temp[np.isinf(value)] = get_inf(dtype, inf_sign)
value = temp
return value

View File

@@ -1,10 +1,13 @@
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from typing import Dict as TypingDict
import numpy as np
from .space import Space
class Dict(Space, Mapping):
class Dict(Space[TypingDict[str, Space]], Mapping):
"""
A dictionary of simpler spaces.
@@ -34,7 +37,12 @@ class Dict(Space, Mapping):
})
"""
def __init__(self, spaces=None, seed=None, **spaces_kwargs):
def __init__(
self,
spaces: dict[str, Space] | None = None,
seed: dict | int | None = None,
**spaces_kwargs: Space
):
assert (spaces is None) or (
not spaces_kwargs
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
@@ -57,10 +65,10 @@ class Dict(Space, Mapping):
space, Space
), "Values of the dict should be instances of gym.Space"
super().__init__(
None, None, seed
None, None, seed # type: ignore
) # None for shape and dtype, since it'll require special handling
def seed(self, seed=None):
def seed(self, seed: dict | int | None = None) -> list:
seeds = []
if isinstance(seed, dict):
for key, seed_key in zip(self.spaces, seed):
@@ -97,10 +105,10 @@ class Dict(Space, Mapping):
return seeds
def sample(self):
def sample(self) -> dict:
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
def contains(self, x):
def contains(self, x) -> bool:
if not isinstance(x, dict) or len(x) != len(self.spaces):
return False
for k, space in self.spaces.items():
@@ -119,29 +127,30 @@ class Dict(Space, Mapping):
def __iter__(self):
yield from self.spaces
def __len__(self):
def __len__(self) -> int:
return len(self.spaces)
def __repr__(self):
def __repr__(self) -> str:
return (
"Dict("
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
+ ")"
)
def to_jsonable(self, sample_n):
def to_jsonable(self, sample_n: list) -> dict:
# serialize as dict-repr of vectors
return {
key: space.to_jsonable([sample[key] for sample in sample_n])
for key, space in self.spaces.items()
}
def from_jsonable(self, sample_n):
dict_of_list = {}
def from_jsonable(self, sample_n: dict[str, list]) -> list:
dict_of_list: dict[str, list] = {}
for key, space in self.spaces.items():
dict_of_list[key] = space.from_jsonable(sample_n[key])
ret = []
for i, _ in enumerate(dict_of_list[key]):
n_elements = len(next(iter(dict_of_list.values())))
for i in range(n_elements):
entry = {}
for key, value in dict_of_list.items():
entry[key] = value[i]

View File

@@ -1,8 +1,10 @@
from typing import Optional
import numpy as np
from .space import Space
class Discrete(Space):
class Discrete(Space[int]):
r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`.
A start value can be optionally specified to shift the range
@@ -15,33 +17,33 @@ class Discrete(Space):
"""
def __init__(self, n, seed=None, start=0):
def __init__(self, n: int, seed: Optional[int] = None, start: int = 0):
assert n > 0, "n (counts) have to be positive"
assert isinstance(start, (int, np.integer))
self.n = int(n)
self.start = int(start)
super().__init__((), np.int64, seed)
def sample(self):
def sample(self) -> int:
return self.start + self.np_random.randint(self.n)
def contains(self, x):
def contains(self, x) -> bool:
if isinstance(x, int):
as_int = x
elif isinstance(x, (np.generic, np.ndarray)) and (
x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
):
as_int = int(x)
as_int = int(x) # type: ignore
else:
return False
return self.start <= as_int < self.start + self.n
def __repr__(self):
def __repr__(self) -> str:
if self.start != 0:
return "Discrete(%d, start=%d)" % (self.n, self.start)
return "Discrete(%d)" % self.n
def __eq__(self, other):
def __eq__(self, other) -> bool:
return (
isinstance(other, Discrete)
and self.n == other.n

View File

@@ -1,9 +1,11 @@
from collections.abc import Sequence
from __future__ import annotations
from typing import Optional, Union, Sequence
import numpy as np
from .space import Space
class MultiBinary(Space):
class MultiBinary(Space[np.ndarray]):
"""
An n-shape binary space.
@@ -27,7 +29,9 @@ class MultiBinary(Space):
"""
def __init__(self, n, seed=None):
def __init__(
self, n: Union[np.ndarray, Sequence[int], int], seed: Optional[int] = None
):
if isinstance(n, (Sequence, np.ndarray)):
self.n = input_n = tuple(int(i) for i in n)
else:
@@ -38,24 +42,29 @@ class MultiBinary(Space):
super().__init__(input_n, np.int8, seed)
def sample(self):
@property
def shape(self) -> tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape # type: ignore
def sample(self) -> np.ndarray:
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
def contains(self, x):
def contains(self, x) -> bool:
if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check
if self.shape != x.shape:
return False
return ((x == 0) | (x == 1)).all()
def to_jsonable(self, sample_n):
def to_jsonable(self, sample_n) -> list:
return np.array(sample_n).tolist()
def from_jsonable(self, sample_n):
def from_jsonable(self, sample_n) -> list:
return [np.asarray(sample) for sample in sample_n]
def __repr__(self):
def __repr__(self) -> str:
return f"MultiBinary({self.n})"
def __eq__(self, other):
def __eq__(self, other) -> bool:
return isinstance(other, MultiBinary) and self.n == other.n

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence
import numpy as np
from gym import logger
@@ -5,7 +7,7 @@ from .space import Space
from .discrete import Discrete
class MultiDiscrete(Space):
class MultiDiscrete(Space[np.ndarray]):
"""
- The multi-discrete action space consists of a series of discrete action spaces with different number of actions in each
- It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space
@@ -26,7 +28,7 @@ class MultiDiscrete(Space):
"""
def __init__(self, nvec, dtype=np.int64, seed=None):
def __init__(self, nvec: list[int], dtype=np.int64, seed=None):
"""
nvec: vector of counts of each categorical variable
"""
@@ -35,15 +37,20 @@ class MultiDiscrete(Space):
super().__init__(self.nvec.shape, dtype, seed)
def sample(self):
@property
def shape(self) -> tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape # type: ignore
def sample(self) -> np.ndarray:
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
def contains(self, x):
def contains(self, x) -> bool:
if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check
# if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
# is within correct bounds for space dtype (even though x does not have to be unsigned)
return x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all()
return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all())
def to_jsonable(self, sample_n):
return [sample.tolist() for sample in sample_n]

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
from typing import (
TypeVar,
Generic,
Optional,
Sequence,
Union,
Iterable,
Mapping,
Tuple,
Type,
)
import numpy as np
@@ -32,8 +33,13 @@ class Space(Generic[T_cov]):
not handle custom spaces properly. Use custom spaces with care.
"""
def __init__(self, shape: Optional[Sequence[int]] = None, dtype=None, seed=None):
import numpy as np # takes about 300-400ms to import, so we load lazily
def __init__(
self,
shape: Optional[Sequence[int]] = None,
dtype: Optional[Type | str] = None,
seed: Optional[int] = None,
):
import numpy as np # noqa ## takes about 300-400ms to import, so we load lazily
self._shape = None if shape is None else tuple(shape)
self.dtype = None if dtype is None else np.dtype(dtype)
@@ -42,7 +48,7 @@ class Space(Generic[T_cov]):
self.seed(seed)
@property
def np_random(self) -> np.random.RandomState:
def np_random(self) -> seeding.RandomNumberGenerator:
"""Lazily seed the rng since this is expensive and only needed if
sampling from this space.
"""
@@ -52,7 +58,7 @@ class Space(Generic[T_cov]):
return self._np_random # type: ignore ## self.seed() call guarantees right type.
@property
def shape(self) -> Optional[Tuple[int, ...]]:
def shape(self) -> Optional[tuple[int, ...]]:
"""Return the shape of the space as an immutable property"""
return self._shape
@@ -61,7 +67,7 @@ class Space(Generic[T_cov]):
uniform or non-uniform sampling based on boundedness of space."""
raise NotImplementedError
def seed(self, seed: Optional[int] = None):
def seed(self, seed: Optional[int] = None) -> list:
"""Seed the PRNG of this space."""
self._np_random, seed = seeding.np_random(seed)
return [seed]
@@ -76,7 +82,7 @@ class Space(Generic[T_cov]):
def __contains__(self, x) -> bool:
return self.contains(x)
def __setstate__(self, state: Union[Iterable, Mapping]):
def __setstate__(self, state: Iterable | Mapping):
# Don't mutate the original state
state = dict(state)
@@ -95,12 +101,12 @@ class Space(Generic[T_cov]):
# Update our state
self.__dict__.update(state)
def to_jsonable(self, sample_n):
def to_jsonable(self, sample_n: Sequence[T_cov]) -> list:
"""Convert a batch of samples from this space to a JSONable data type."""
# By default, assume identity is JSONable
return sample_n
return list(sample_n)
def from_jsonable(self, sample_n):
def from_jsonable(self, sample_n: list) -> list[T_cov]:
"""Convert a JSONable data type to a batch of samples from this space."""
# By default, assume identity is JSONable
return sample_n

View File

@@ -1,8 +1,9 @@
from typing import Iterable, List, Optional, Union
import numpy as np
from .space import Space
class Tuple(Space):
class Tuple(Space[tuple]):
"""
A tuple (i.e., product) of simpler spaces
@@ -10,16 +11,18 @@ class Tuple(Space):
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
"""
def __init__(self, spaces, seed=None):
def __init__(
self, spaces: Iterable[Space], seed: Optional[Union[int, List[int]]] = None
):
spaces = tuple(spaces)
self.spaces = spaces
for space in spaces:
assert isinstance(
space, Space
), "Elements of the tuple must be instances of gym.Space"
super().__init__(None, None, seed)
super().__init__(None, None, seed) # type: ignore
def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
seeds = []
if isinstance(seed, list):
@@ -50,10 +53,10 @@ class Tuple(Space):
return seeds
def sample(self):
def sample(self) -> tuple:
return tuple(space.sample() for space in self.spaces)
def contains(self, x):
def contains(self, x) -> bool:
if isinstance(x, (list, np.ndarray)):
x = tuple(x) # Promote list and ndarray to tuple for contains check
return (
@@ -62,17 +65,17 @@ class Tuple(Space):
and all(space.contains(part) for (space, part) in zip(self.spaces, x))
)
def __repr__(self):
def __repr__(self) -> str:
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
def to_jsonable(self, sample_n):
def to_jsonable(self, sample_n) -> list:
# serialize as list-repr of tuple of vectors
return [
space.to_jsonable([sample[i] for sample in sample_n])
for i, space in enumerate(self.spaces)
]
def from_jsonable(self, sample_n):
def from_jsonable(self, sample_n) -> list:
return [
sample
for sample in zip(
@@ -83,11 +86,11 @@ class Tuple(Space):
)
]
def __getitem__(self, index):
def __getitem__(self, index: int) -> Space:
return self.spaces[index]
def __len__(self):
def __len__(self) -> int:
return len(self.spaces)
def __eq__(self, other):
def __eq__(self, other) -> bool:
return isinstance(other, Tuple) and self.spaces == other.spaces

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
from collections import OrderedDict
from functools import singledispatch, reduce
from typing import TypeVar, Union
import numpy as np
import operator as op
@@ -9,10 +12,11 @@ from gym.spaces import MultiDiscrete
from gym.spaces import MultiBinary
from gym.spaces import Tuple
from gym.spaces import Dict
from gym.spaces import Space
@singledispatch
def flatdim(space):
def flatdim(space: Space) -> int:
"""Return the number of dimensions a flattened equivalent of this space
would have.
@@ -24,32 +28,35 @@ def flatdim(space):
@flatdim.register(Box)
@flatdim.register(MultiBinary)
def _flatdim_box_multibinary(space):
def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int:
return reduce(op.mul, space.shape, 1)
@flatdim.register(Discrete)
def _flatdim_discrete(space):
def _flatdim_discrete(space: Discrete) -> int:
return int(space.n)
@flatdim.register(MultiDiscrete)
def _flatdim_multidiscrete(space):
def _flatdim_multidiscrete(space: MultiDiscrete) -> int:
return int(np.sum(space.nvec))
@flatdim.register(Tuple)
def _flatdim_tuple(space):
def _flatdim_tuple(space: Tuple) -> int:
return sum(flatdim(s) for s in space.spaces)
@flatdim.register(Dict)
def _flatdim_dict(space):
def _flatdim_dict(space: Dict) -> int:
return sum(flatdim(s) for s in space.spaces.values())
T = TypeVar("T")
@singledispatch
def flatten(space, x):
def flatten(space: Space[T], x: T) -> np.ndarray:
"""Flatten a data point from a space.
This is useful when e.g. points from spaces must be passed to a neural
@@ -64,19 +71,19 @@ def flatten(space, x):
@flatten.register(Box)
@flatten.register(MultiBinary)
def _flatten_box_multibinary(space, x):
def _flatten_box_multibinary(space, x) -> np.ndarray:
return np.asarray(x, dtype=space.dtype).flatten()
@flatten.register(Discrete)
def _flatten_discrete(space, x):
def _flatten_discrete(space, x) -> np.ndarray:
onehot = np.zeros(space.n, dtype=space.dtype)
onehot[x] = 1
return onehot
@flatten.register(MultiDiscrete)
def _flatten_multidiscrete(space, x):
def _flatten_multidiscrete(space, x) -> np.ndarray:
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
offsets[1:] = np.cumsum(space.nvec.flatten())
@@ -86,17 +93,17 @@ def _flatten_multidiscrete(space, x):
@flatten.register(Tuple)
def _flatten_tuple(space, x):
def _flatten_tuple(space, x) -> np.ndarray:
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
@flatten.register(Dict)
def _flatten_dict(space, x):
def _flatten_dict(space, x) -> np.ndarray:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
@singledispatch
def unflatten(space, x):
def unflatten(space: Space[T], x: np.ndarray) -> T:
"""Unflatten a data point from a space.
This reverses the transformation applied by ``flatten()``. You must ensure
@@ -111,17 +118,17 @@ def unflatten(space, x):
@unflatten.register(Box)
@unflatten.register(MultiBinary)
def _unflatten_box_multibinary(space, x):
def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.ndarray:
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
@unflatten.register(Discrete)
def _unflatten_discrete(space, x):
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
return int(np.nonzero(x)[0][0])
@unflatten.register(MultiDiscrete)
def _unflatten_multidiscrete(space, x):
def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
offsets[1:] = np.cumsum(space.nvec.flatten())
@@ -130,7 +137,7 @@ def _unflatten_multidiscrete(space, x):
@unflatten.register(Tuple)
def _unflatten_tuple(space, x):
def _unflatten_tuple(space: Tuple, x: np.ndarray) -> tuple:
dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1]))
return tuple(
@@ -139,7 +146,7 @@ def _unflatten_tuple(space, x):
@unflatten.register(Dict)
def _unflatten_dict(space, x):
def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1]))
return OrderedDict(
@@ -151,7 +158,7 @@ def _unflatten_dict(space, x):
@singledispatch
def flatten_space(space):
def flatten_space(space: Space) -> Box:
"""Flatten a space into a single ``Box``.
This is equivalent to ``flatten()``, but operates on the space itself. The
@@ -193,32 +200,32 @@ def flatten_space(space):
@flatten_space.register(Box)
def _flatten_space_box(space):
def _flatten_space_box(space: Box) -> Box:
return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype)
@flatten_space.register(Discrete)
@flatten_space.register(MultiBinary)
@flatten_space.register(MultiDiscrete)
def _flatten_space_binary(space):
def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box:
return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype)
@flatten_space.register(Tuple)
def _flatten_space_tuple(space):
space = [flatten_space(s) for s in space.spaces]
def _flatten_space_tuple(space: Tuple) -> Box:
space_list = [flatten_space(s) for s in space.spaces]
return Box(
low=np.concatenate([s.low for s in space]),
high=np.concatenate([s.high for s in space]),
dtype=np.result_type(*[s.dtype for s in space]),
low=np.concatenate([s.low for s in space_list]),
high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space_list]),
)
@flatten_space.register(Dict)
def _flatten_space_dict(space):
space = [flatten_space(s) for s in space.spaces.values()]
def _flatten_space_dict(space: Dict) -> Box:
space_list = [flatten_space(s) for s in space.spaces.values()]
return Box(
low=np.concatenate([s.low for s in space]),
high=np.concatenate([s.high for s in space]),
dtype=np.result_type(*[s.dtype for s in space]),
low=np.concatenate([s.low for s in space_list]),
high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space_list]),
)

View File

@@ -1,5 +1,5 @@
import hashlib
from typing import Optional, List, Union
from typing import Optional, List, Tuple, Union, Any
import os
import struct
@@ -10,7 +10,7 @@ from gym import error
from gym.logger import deprecation
def np_random(seed: Optional[int] = None) -> tuple:
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")

View File

@@ -4,7 +4,7 @@ include = [
"gym/version.py",
"gym/logger.py",
"gym/envs/registration.py",
"gym/spaces/space.py",
"gym/spaces/**.py",
"gym/core.py",
"gym/utils/seeding.py"
]