From ad79b0ad0f2518ba9c4b4f2d58a72e9a3c49d85f Mon Sep 17 00:00:00 2001 From: Ilya Kamen Date: Mon, 24 Jan 2022 23:22:11 +0100 Subject: [PATCH] typing in gym.spaces (#2541) * typing in spaces.Box and spaces.Discrete * adds typing to dict and tuple spaces * Typecheck all spaces * Explicit regex to include all files under space folder * Style: use native types and __future__ annotations * Allow only specific strings for Box.is_bounded args * Add typing to changes from #2517 * Remove Literal as it's not supported by py3.7 * Use more recent version of pyright * Avoid name clash for type checker * Revert "Avoid name clash for type checker" This reverts commit 1aaf3e0e0328171623a17a997b65fe734bc0afb1. * Ignore the error. It's reported as probable bug at https://github.com/microsoft/pyright/issues/2852 * rebase and add typing for `_short_repr` --- .github/workflows/lint_python.yml | 2 +- gym/spaces/box.py | 120 ++++++++++++++++-------------- gym/spaces/dict.py | 33 +++++--- gym/spaces/discrete.py | 16 ++-- gym/spaces/multi_binary.py | 27 ++++--- gym/spaces/multi_discrete.py | 17 +++-- gym/spaces/space.py | 28 ++++--- gym/spaces/tuple.py | 27 ++++--- gym/spaces/utils.py | 69 +++++++++-------- gym/utils/seeding.py | 4 +- pyproject.toml | 2 +- 11 files changed, 199 insertions(+), 146 deletions(-) diff --git a/.github/workflows/lint_python.yml b/.github/workflows/lint_python.yml index abd26677f..60eee2c07 100644 --- a/.github/workflows/lint_python.yml +++ b/.github/workflows/lint_python.yml @@ -22,7 +22,7 @@ jobs: python-version: [ "3.7"] fail-fast: false env: - PYRIGHT_VERSION: 1.1.183 + PYRIGHT_VERSION: 1.1.204 steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 diff --git a/gym/spaces/box.py b/gym/spaces/box.py index b18a99580..42b3be219 100644 --- a/gym/spaces/box.py +++ b/gym/spaces/box.py @@ -1,10 +1,14 @@ +from __future__ import annotations + +from typing import Tuple, SupportsFloat, Union, Type, Optional, Sequence + import numpy as np from .space import Space from gym import logger -def _short_repr(arr): +def _short_repr(arr: np.ndarray) -> str: """Create a shortened string representation of a numpy array. If arr is a multiple of the all-ones vector, return a string representation of the multiplier. @@ -15,7 +19,7 @@ def _short_repr(arr): return str(arr) -class Box(Space): +class Box(Space[np.ndarray]): """ A (possibly unbounded) box in R^n. Specifically, a Box represents the Cartesian product of n closed intervals. Each interval has the form of one @@ -33,66 +37,47 @@ class Box(Space): """ - def __init__(self, low, high, shape=None, dtype=np.float32, seed=None): + def __init__( + self, + low: Union[SupportsFloat, np.ndarray], + high: Union[SupportsFloat, np.ndarray], + shape: Optional[Sequence[int]] = None, + dtype: Type = np.float32, + seed: Optional[int] = None, + ): assert dtype is not None, "dtype must be explicitly provided. " self.dtype = np.dtype(dtype) # determine shape if it isn't provided directly if shape is not None: shape = tuple(shape) - assert ( - np.isscalar(low) or low.shape == shape - ), "low.shape doesn't match provided shape" - assert ( - np.isscalar(high) or high.shape == shape - ), "high.shape doesn't match provided shape" elif not np.isscalar(low): - shape = low.shape - assert ( - np.isscalar(high) or high.shape == shape - ), "high.shape doesn't match low.shape" + shape = low.shape # type: ignore elif not np.isscalar(high): - shape = high.shape - assert ( - np.isscalar(low) or low.shape == shape - ), "low.shape doesn't match high.shape" + shape = high.shape # type: ignore else: raise ValueError( "shape must be provided or inferred from the shapes of low or high" ) + assert isinstance(shape, tuple) - # handle infinite bounds and broadcast at the same time if needed - if np.isscalar(low): - low = get_inf(dtype, "-") if np.isinf(low) else low - low = np.full(shape, low, dtype=dtype) - else: - if np.any(np.isinf(low)): - # create new array with dtype, but maintain old one to preserve np.inf - temp_low = low.astype(dtype) - temp_low[np.isinf(low)] = get_inf(dtype, "-") - low = temp_low + low = _broadcast(low, dtype, shape, inf_sign="-") # type: ignore + high = _broadcast(high, dtype, shape, inf_sign="+") - if np.isscalar(high): - high = get_inf(dtype, "+") if np.isinf(high) else high - high = np.full(shape, high, dtype=dtype) - else: - if np.any(np.isinf(high)): - # create new array with dtype, but maintain old one to preserve np.inf - temp_high = high.astype(dtype) - temp_high[np.isinf(high)] = get_inf(dtype, "+") - high = temp_high + assert isinstance(low, np.ndarray) + assert low.shape == shape, "low.shape doesn't match provided shape" + assert isinstance(high, np.ndarray) + assert high.shape == shape, "high.shape doesn't match provided shape" - self._shape = shape - self.low = low - self.high = high + self._shape: Tuple[int, ...] = shape - low_precision = get_precision(self.low.dtype) - high_precision = get_precision(self.high.dtype) + low_precision = get_precision(low.dtype) + high_precision = get_precision(high.dtype) dtype_precision = get_precision(self.dtype) - if min(low_precision, high_precision) > dtype_precision: + if min(low_precision, high_precision) > dtype_precision: # type: ignore logger.warn(f"Box bound precision lowered by casting to {self.dtype}") - self.low = self.low.astype(self.dtype) - self.high = self.high.astype(self.dtype) + self.low = low.astype(self.dtype) + self.high = high.astype(self.dtype) self.low_repr = _short_repr(self.low) self.high_repr = _short_repr(self.high) @@ -103,9 +88,14 @@ class Box(Space): super().__init__(self.shape, self.dtype, seed) - def is_bounded(self, manner="both"): - below = np.all(self.bounded_below) - above = np.all(self.bounded_above) + @property + def shape(self) -> Tuple[int, ...]: + """Has stricter type than gym.Space - never None.""" + return self._shape + + def is_bounded(self, manner: str = "both") -> bool: + below = bool(np.all(self.bounded_below)) + above = bool(np.all(self.bounded_above)) if manner == "both": return below and above elif manner == "below": @@ -115,7 +105,7 @@ class Box(Space): else: raise ValueError("manner is not in {'below', 'above', 'both'}") - def sample(self): + def sample(self) -> np.ndarray: """ Generates a single random sample inside of the Box. @@ -158,12 +148,12 @@ class Box(Space): return sample.astype(self.dtype) - def contains(self, x): + def contains(self, x) -> bool: if not isinstance(x, np.ndarray): logger.warn("Casting input x to numpy array.") x = np.asarray(x, dtype=self.dtype) - return ( + return bool( np.can_cast(x.dtype, self.dtype) and x.shape == self.shape and np.all(x >= self.low) @@ -173,13 +163,13 @@ class Box(Space): def to_jsonable(self, sample_n): return np.array(sample_n).tolist() - def from_jsonable(self, sample_n): + def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> list[np.ndarray]: return [np.asarray(sample) for sample in sample_n] - def __repr__(self): + def __repr__(self) -> str: return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})" - def __eq__(self, other): + def __eq__(self, other) -> bool: return ( isinstance(other, Box) and (self.shape == other.shape) @@ -188,7 +178,7 @@ class Box(Space): ) -def get_inf(dtype, sign): +def get_inf(dtype, sign: str) -> SupportsFloat: """Returns an infinite that doesn't break things. `dtype` must be an `np.dtype` `bound` must be either `min` or `max` @@ -211,8 +201,28 @@ def get_inf(dtype, sign): raise ValueError(f"Unknown dtype {dtype} for infinite bounds") -def get_precision(dtype): +def get_precision(dtype) -> SupportsFloat: if np.issubdtype(dtype, np.floating): return np.finfo(dtype).precision else: return np.inf + + +def _broadcast( + value: Union[SupportsFloat, np.ndarray], + dtype, + shape: tuple[int, ...], + inf_sign: str, +) -> np.ndarray: + """handle infinite bounds and broadcast at the same time if needed""" + if np.isscalar(value): + value = get_inf(dtype, inf_sign) if np.isinf(value) else value # type: ignore + value = np.full(shape, value, dtype=dtype) + else: + assert isinstance(value, np.ndarray) + if np.any(np.isinf(value)): + # create new array with dtype, but maintain old one to preserve np.inf + temp = value.astype(dtype) + temp[np.isinf(value)] = get_inf(dtype, inf_sign) + value = temp + return value diff --git a/gym/spaces/dict.py b/gym/spaces/dict.py index 565f121f7..fe12b3575 100644 --- a/gym/spaces/dict.py +++ b/gym/spaces/dict.py @@ -1,10 +1,13 @@ +from __future__ import annotations + from collections import OrderedDict from collections.abc import Mapping, Sequence +from typing import Dict as TypingDict import numpy as np from .space import Space -class Dict(Space, Mapping): +class Dict(Space[TypingDict[str, Space]], Mapping): """ A dictionary of simpler spaces. @@ -34,7 +37,12 @@ class Dict(Space, Mapping): }) """ - def __init__(self, spaces=None, seed=None, **spaces_kwargs): + def __init__( + self, + spaces: dict[str, Space] | None = None, + seed: dict | int | None = None, + **spaces_kwargs: Space + ): assert (spaces is None) or ( not spaces_kwargs ), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)" @@ -57,10 +65,10 @@ class Dict(Space, Mapping): space, Space ), "Values of the dict should be instances of gym.Space" super().__init__( - None, None, seed + None, None, seed # type: ignore ) # None for shape and dtype, since it'll require special handling - def seed(self, seed=None): + def seed(self, seed: dict | int | None = None) -> list: seeds = [] if isinstance(seed, dict): for key, seed_key in zip(self.spaces, seed): @@ -97,10 +105,10 @@ class Dict(Space, Mapping): return seeds - def sample(self): + def sample(self) -> dict: return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()]) - def contains(self, x): + def contains(self, x) -> bool: if not isinstance(x, dict) or len(x) != len(self.spaces): return False for k, space in self.spaces.items(): @@ -119,29 +127,30 @@ class Dict(Space, Mapping): def __iter__(self): yield from self.spaces - def __len__(self): + def __len__(self) -> int: return len(self.spaces) - def __repr__(self): + def __repr__(self) -> str: return ( "Dict(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")" ) - def to_jsonable(self, sample_n): + def to_jsonable(self, sample_n: list) -> dict: # serialize as dict-repr of vectors return { key: space.to_jsonable([sample[key] for sample in sample_n]) for key, space in self.spaces.items() } - def from_jsonable(self, sample_n): - dict_of_list = {} + def from_jsonable(self, sample_n: dict[str, list]) -> list: + dict_of_list: dict[str, list] = {} for key, space in self.spaces.items(): dict_of_list[key] = space.from_jsonable(sample_n[key]) ret = [] - for i, _ in enumerate(dict_of_list[key]): + n_elements = len(next(iter(dict_of_list.values()))) + for i in range(n_elements): entry = {} for key, value in dict_of_list.items(): entry[key] = value[i] diff --git a/gym/spaces/discrete.py b/gym/spaces/discrete.py index 8bb58d0be..08403f839 100644 --- a/gym/spaces/discrete.py +++ b/gym/spaces/discrete.py @@ -1,8 +1,10 @@ +from typing import Optional + import numpy as np from .space import Space -class Discrete(Space): +class Discrete(Space[int]): r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`. A start value can be optionally specified to shift the range @@ -15,33 +17,33 @@ class Discrete(Space): """ - def __init__(self, n, seed=None, start=0): + def __init__(self, n: int, seed: Optional[int] = None, start: int = 0): assert n > 0, "n (counts) have to be positive" assert isinstance(start, (int, np.integer)) self.n = int(n) self.start = int(start) super().__init__((), np.int64, seed) - def sample(self): + def sample(self) -> int: return self.start + self.np_random.randint(self.n) - def contains(self, x): + def contains(self, x) -> bool: if isinstance(x, int): as_int = x elif isinstance(x, (np.generic, np.ndarray)) and ( x.dtype.char in np.typecodes["AllInteger"] and x.shape == () ): - as_int = int(x) + as_int = int(x) # type: ignore else: return False return self.start <= as_int < self.start + self.n - def __repr__(self): + def __repr__(self) -> str: if self.start != 0: return "Discrete(%d, start=%d)" % (self.n, self.start) return "Discrete(%d)" % self.n - def __eq__(self, other): + def __eq__(self, other) -> bool: return ( isinstance(other, Discrete) and self.n == other.n diff --git a/gym/spaces/multi_binary.py b/gym/spaces/multi_binary.py index 315da00a2..55ae4f61d 100644 --- a/gym/spaces/multi_binary.py +++ b/gym/spaces/multi_binary.py @@ -1,9 +1,11 @@ -from collections.abc import Sequence +from __future__ import annotations + +from typing import Optional, Union, Sequence import numpy as np from .space import Space -class MultiBinary(Space): +class MultiBinary(Space[np.ndarray]): """ An n-shape binary space. @@ -27,7 +29,9 @@ class MultiBinary(Space): """ - def __init__(self, n, seed=None): + def __init__( + self, n: Union[np.ndarray, Sequence[int], int], seed: Optional[int] = None + ): if isinstance(n, (Sequence, np.ndarray)): self.n = input_n = tuple(int(i) for i in n) else: @@ -38,24 +42,29 @@ class MultiBinary(Space): super().__init__(input_n, np.int8, seed) - def sample(self): + @property + def shape(self) -> tuple[int, ...]: + """Has stricter type than gym.Space - never None.""" + return self._shape # type: ignore + + def sample(self) -> np.ndarray: return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype) - def contains(self, x): + def contains(self, x) -> bool: if isinstance(x, Sequence): x = np.array(x) # Promote list to array for contains check if self.shape != x.shape: return False return ((x == 0) | (x == 1)).all() - def to_jsonable(self, sample_n): + def to_jsonable(self, sample_n) -> list: return np.array(sample_n).tolist() - def from_jsonable(self, sample_n): + def from_jsonable(self, sample_n) -> list: return [np.asarray(sample) for sample in sample_n] - def __repr__(self): + def __repr__(self) -> str: return f"MultiBinary({self.n})" - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, MultiBinary) and self.n == other.n diff --git a/gym/spaces/multi_discrete.py b/gym/spaces/multi_discrete.py index 02316f6b4..4c1676d7c 100644 --- a/gym/spaces/multi_discrete.py +++ b/gym/spaces/multi_discrete.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Sequence import numpy as np from gym import logger @@ -5,7 +7,7 @@ from .space import Space from .discrete import Discrete -class MultiDiscrete(Space): +class MultiDiscrete(Space[np.ndarray]): """ - The multi-discrete action space consists of a series of discrete action spaces with different number of actions in each - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space @@ -26,7 +28,7 @@ class MultiDiscrete(Space): """ - def __init__(self, nvec, dtype=np.int64, seed=None): + def __init__(self, nvec: list[int], dtype=np.int64, seed=None): """ nvec: vector of counts of each categorical variable """ @@ -35,15 +37,20 @@ class MultiDiscrete(Space): super().__init__(self.nvec.shape, dtype, seed) - def sample(self): + @property + def shape(self) -> tuple[int, ...]: + """Has stricter type than gym.Space - never None.""" + return self._shape # type: ignore + + def sample(self) -> np.ndarray: return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype) - def contains(self, x): + def contains(self, x) -> bool: if isinstance(x, Sequence): x = np.array(x) # Promote list to array for contains check # if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x # is within correct bounds for space dtype (even though x does not have to be unsigned) - return x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all() + return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all()) def to_jsonable(self, sample_n): return [sample.tolist() for sample in sample_n] diff --git a/gym/spaces/space.py b/gym/spaces/space.py index ad7cc62bc..c068921c7 100644 --- a/gym/spaces/space.py +++ b/gym/spaces/space.py @@ -1,12 +1,13 @@ +from __future__ import annotations + from typing import ( TypeVar, Generic, Optional, Sequence, - Union, Iterable, Mapping, - Tuple, + Type, ) import numpy as np @@ -32,8 +33,13 @@ class Space(Generic[T_cov]): not handle custom spaces properly. Use custom spaces with care. """ - def __init__(self, shape: Optional[Sequence[int]] = None, dtype=None, seed=None): - import numpy as np # takes about 300-400ms to import, so we load lazily + def __init__( + self, + shape: Optional[Sequence[int]] = None, + dtype: Optional[Type | str] = None, + seed: Optional[int] = None, + ): + import numpy as np # noqa ## takes about 300-400ms to import, so we load lazily self._shape = None if shape is None else tuple(shape) self.dtype = None if dtype is None else np.dtype(dtype) @@ -42,7 +48,7 @@ class Space(Generic[T_cov]): self.seed(seed) @property - def np_random(self) -> np.random.RandomState: + def np_random(self) -> seeding.RandomNumberGenerator: """Lazily seed the rng since this is expensive and only needed if sampling from this space. """ @@ -52,7 +58,7 @@ class Space(Generic[T_cov]): return self._np_random # type: ignore ## self.seed() call guarantees right type. @property - def shape(self) -> Optional[Tuple[int, ...]]: + def shape(self) -> Optional[tuple[int, ...]]: """Return the shape of the space as an immutable property""" return self._shape @@ -61,7 +67,7 @@ class Space(Generic[T_cov]): uniform or non-uniform sampling based on boundedness of space.""" raise NotImplementedError - def seed(self, seed: Optional[int] = None): + def seed(self, seed: Optional[int] = None) -> list: """Seed the PRNG of this space.""" self._np_random, seed = seeding.np_random(seed) return [seed] @@ -76,7 +82,7 @@ class Space(Generic[T_cov]): def __contains__(self, x) -> bool: return self.contains(x) - def __setstate__(self, state: Union[Iterable, Mapping]): + def __setstate__(self, state: Iterable | Mapping): # Don't mutate the original state state = dict(state) @@ -95,12 +101,12 @@ class Space(Generic[T_cov]): # Update our state self.__dict__.update(state) - def to_jsonable(self, sample_n): + def to_jsonable(self, sample_n: Sequence[T_cov]) -> list: """Convert a batch of samples from this space to a JSONable data type.""" # By default, assume identity is JSONable - return sample_n + return list(sample_n) - def from_jsonable(self, sample_n): + def from_jsonable(self, sample_n: list) -> list[T_cov]: """Convert a JSONable data type to a batch of samples from this space.""" # By default, assume identity is JSONable return sample_n diff --git a/gym/spaces/tuple.py b/gym/spaces/tuple.py index 0052783ed..2bff24eb5 100644 --- a/gym/spaces/tuple.py +++ b/gym/spaces/tuple.py @@ -1,8 +1,9 @@ +from typing import Iterable, List, Optional, Union import numpy as np from .space import Space -class Tuple(Space): +class Tuple(Space[tuple]): """ A tuple (i.e., product) of simpler spaces @@ -10,16 +11,18 @@ class Tuple(Space): self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3))) """ - def __init__(self, spaces, seed=None): + def __init__( + self, spaces: Iterable[Space], seed: Optional[Union[int, List[int]]] = None + ): spaces = tuple(spaces) self.spaces = spaces for space in spaces: assert isinstance( space, Space ), "Elements of the tuple must be instances of gym.Space" - super().__init__(None, None, seed) + super().__init__(None, None, seed) # type: ignore - def seed(self, seed=None): + def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list: seeds = [] if isinstance(seed, list): @@ -50,10 +53,10 @@ class Tuple(Space): return seeds - def sample(self): + def sample(self) -> tuple: return tuple(space.sample() for space in self.spaces) - def contains(self, x): + def contains(self, x) -> bool: if isinstance(x, (list, np.ndarray)): x = tuple(x) # Promote list and ndarray to tuple for contains check return ( @@ -62,17 +65,17 @@ class Tuple(Space): and all(space.contains(part) for (space, part) in zip(self.spaces, x)) ) - def __repr__(self): + def __repr__(self) -> str: return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")" - def to_jsonable(self, sample_n): + def to_jsonable(self, sample_n) -> list: # serialize as list-repr of tuple of vectors return [ space.to_jsonable([sample[i] for sample in sample_n]) for i, space in enumerate(self.spaces) ] - def from_jsonable(self, sample_n): + def from_jsonable(self, sample_n) -> list: return [ sample for sample in zip( @@ -83,11 +86,11 @@ class Tuple(Space): ) ] - def __getitem__(self, index): + def __getitem__(self, index: int) -> Space: return self.spaces[index] - def __len__(self): + def __len__(self) -> int: return len(self.spaces) - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, Tuple) and self.spaces == other.spaces diff --git a/gym/spaces/utils.py b/gym/spaces/utils.py index 387be60ad..bc832b56c 100644 --- a/gym/spaces/utils.py +++ b/gym/spaces/utils.py @@ -1,5 +1,8 @@ +from __future__ import annotations + from collections import OrderedDict from functools import singledispatch, reduce +from typing import TypeVar, Union import numpy as np import operator as op @@ -9,10 +12,11 @@ from gym.spaces import MultiDiscrete from gym.spaces import MultiBinary from gym.spaces import Tuple from gym.spaces import Dict +from gym.spaces import Space @singledispatch -def flatdim(space): +def flatdim(space: Space) -> int: """Return the number of dimensions a flattened equivalent of this space would have. @@ -24,32 +28,35 @@ def flatdim(space): @flatdim.register(Box) @flatdim.register(MultiBinary) -def _flatdim_box_multibinary(space): +def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int: return reduce(op.mul, space.shape, 1) @flatdim.register(Discrete) -def _flatdim_discrete(space): +def _flatdim_discrete(space: Discrete) -> int: return int(space.n) @flatdim.register(MultiDiscrete) -def _flatdim_multidiscrete(space): +def _flatdim_multidiscrete(space: MultiDiscrete) -> int: return int(np.sum(space.nvec)) @flatdim.register(Tuple) -def _flatdim_tuple(space): +def _flatdim_tuple(space: Tuple) -> int: return sum(flatdim(s) for s in space.spaces) @flatdim.register(Dict) -def _flatdim_dict(space): +def _flatdim_dict(space: Dict) -> int: return sum(flatdim(s) for s in space.spaces.values()) +T = TypeVar("T") + + @singledispatch -def flatten(space, x): +def flatten(space: Space[T], x: T) -> np.ndarray: """Flatten a data point from a space. This is useful when e.g. points from spaces must be passed to a neural @@ -64,19 +71,19 @@ def flatten(space, x): @flatten.register(Box) @flatten.register(MultiBinary) -def _flatten_box_multibinary(space, x): +def _flatten_box_multibinary(space, x) -> np.ndarray: return np.asarray(x, dtype=space.dtype).flatten() @flatten.register(Discrete) -def _flatten_discrete(space, x): +def _flatten_discrete(space, x) -> np.ndarray: onehot = np.zeros(space.n, dtype=space.dtype) onehot[x] = 1 return onehot @flatten.register(MultiDiscrete) -def _flatten_multidiscrete(space, x): +def _flatten_multidiscrete(space, x) -> np.ndarray: offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype) offsets[1:] = np.cumsum(space.nvec.flatten()) @@ -86,17 +93,17 @@ def _flatten_multidiscrete(space, x): @flatten.register(Tuple) -def _flatten_tuple(space, x): +def _flatten_tuple(space, x) -> np.ndarray: return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)]) @flatten.register(Dict) -def _flatten_dict(space, x): +def _flatten_dict(space, x) -> np.ndarray: return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()]) @singledispatch -def unflatten(space, x): +def unflatten(space: Space[T], x: np.ndarray) -> T: """Unflatten a data point from a space. This reverses the transformation applied by ``flatten()``. You must ensure @@ -111,17 +118,17 @@ def unflatten(space, x): @unflatten.register(Box) @unflatten.register(MultiBinary) -def _unflatten_box_multibinary(space, x): +def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.ndarray: return np.asarray(x, dtype=space.dtype).reshape(space.shape) @unflatten.register(Discrete) -def _unflatten_discrete(space, x): +def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int: return int(np.nonzero(x)[0][0]) @unflatten.register(MultiDiscrete) -def _unflatten_multidiscrete(space, x): +def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray: offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype) offsets[1:] = np.cumsum(space.nvec.flatten()) @@ -130,7 +137,7 @@ def _unflatten_multidiscrete(space, x): @unflatten.register(Tuple) -def _unflatten_tuple(space, x): +def _unflatten_tuple(space: Tuple, x: np.ndarray) -> tuple: dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_) list_flattened = np.split(x, np.cumsum(dims[:-1])) return tuple( @@ -139,7 +146,7 @@ def _unflatten_tuple(space, x): @unflatten.register(Dict) -def _unflatten_dict(space, x): +def _unflatten_dict(space: Dict, x: np.ndarray) -> dict: dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_) list_flattened = np.split(x, np.cumsum(dims[:-1])) return OrderedDict( @@ -151,7 +158,7 @@ def _unflatten_dict(space, x): @singledispatch -def flatten_space(space): +def flatten_space(space: Space) -> Box: """Flatten a space into a single ``Box``. This is equivalent to ``flatten()``, but operates on the space itself. The @@ -193,32 +200,32 @@ def flatten_space(space): @flatten_space.register(Box) -def _flatten_space_box(space): +def _flatten_space_box(space: Box) -> Box: return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype) @flatten_space.register(Discrete) @flatten_space.register(MultiBinary) @flatten_space.register(MultiDiscrete) -def _flatten_space_binary(space): +def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box: return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype) @flatten_space.register(Tuple) -def _flatten_space_tuple(space): - space = [flatten_space(s) for s in space.spaces] +def _flatten_space_tuple(space: Tuple) -> Box: + space_list = [flatten_space(s) for s in space.spaces] return Box( - low=np.concatenate([s.low for s in space]), - high=np.concatenate([s.high for s in space]), - dtype=np.result_type(*[s.dtype for s in space]), + low=np.concatenate([s.low for s in space_list]), + high=np.concatenate([s.high for s in space_list]), + dtype=np.result_type(*[s.dtype for s in space_list]), ) @flatten_space.register(Dict) -def _flatten_space_dict(space): - space = [flatten_space(s) for s in space.spaces.values()] +def _flatten_space_dict(space: Dict) -> Box: + space_list = [flatten_space(s) for s in space.spaces.values()] return Box( - low=np.concatenate([s.low for s in space]), - high=np.concatenate([s.high for s in space]), - dtype=np.result_type(*[s.dtype for s in space]), + low=np.concatenate([s.low for s in space_list]), + high=np.concatenate([s.high for s in space_list]), + dtype=np.result_type(*[s.dtype for s in space_list]), ) diff --git a/gym/utils/seeding.py b/gym/utils/seeding.py index 921632822..84571a479 100644 --- a/gym/utils/seeding.py +++ b/gym/utils/seeding.py @@ -1,5 +1,5 @@ import hashlib -from typing import Optional, List, Union +from typing import Optional, List, Tuple, Union, Any import os import struct @@ -10,7 +10,7 @@ from gym import error from gym.logger import deprecation -def np_random(seed: Optional[int] = None) -> tuple: +def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]: if seed is not None and not (isinstance(seed, int) and 0 <= seed): raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}") diff --git a/pyproject.toml b/pyproject.toml index 181f81b52..7dedbd7fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ include = [ "gym/version.py", "gym/logger.py", "gym/envs/registration.py", - "gym/spaces/space.py", + "gym/spaces/**.py", "gym/core.py", "gym/utils/seeding.py" ]