typing in gym.spaces (#2541)

* typing in spaces.Box and spaces.Discrete

* adds typing to dict and tuple spaces

* Typecheck all spaces

* Explicit regex to include all files under the spaces folder

* Style: use native types and __future__ annotations

* Allow only specific strings for Box.is_bounded args

* Add typing to changes from #2517

* Remove Literal as it's not supported by py3.7

* Use more recent version of pyright

* Avoid name clash for type checker

* Revert "Avoid name clash for type checker"

This reverts commit 1aaf3e0e0328171623a17a997b65fe734bc0afb1.

* Ignore the error. It's reported as a probable bug at https://github.com/microsoft/pyright/issues/2852

* rebase and add typing for `_short_repr`
This commit is contained in:
Ilya Kamen
2022-01-24 23:22:11 +01:00
committed by GitHub
parent fcbff7de12
commit ad79b0ad0f
11 changed files with 199 additions and 146 deletions

View File

@@ -22,7 +22,7 @@ jobs:
python-version: [ "3.7"] python-version: [ "3.7"]
fail-fast: false fail-fast: false
env: env:
PYRIGHT_VERSION: 1.1.183 PYRIGHT_VERSION: 1.1.204
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- uses: actions/setup-python@v2 - uses: actions/setup-python@v2

View File

@@ -1,10 +1,14 @@
from __future__ import annotations
from typing import Tuple, SupportsFloat, Union, Type, Optional, Sequence
import numpy as np import numpy as np
from .space import Space from .space import Space
from gym import logger from gym import logger
def _short_repr(arr): def _short_repr(arr: np.ndarray) -> str:
"""Create a shortened string representation of a numpy array. """Create a shortened string representation of a numpy array.
If arr is a multiple of the all-ones vector, return a string representation of the multiplier. If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
@@ -15,7 +19,7 @@ def _short_repr(arr):
return str(arr) return str(arr)
class Box(Space): class Box(Space[np.ndarray]):
""" """
A (possibly unbounded) box in R^n. Specifically, a Box represents the A (possibly unbounded) box in R^n. Specifically, a Box represents the
Cartesian product of n closed intervals. Each interval has the form of one Cartesian product of n closed intervals. Each interval has the form of one
@@ -33,66 +37,47 @@ class Box(Space):
""" """
def __init__(self, low, high, shape=None, dtype=np.float32, seed=None): def __init__(
self,
low: Union[SupportsFloat, np.ndarray],
high: Union[SupportsFloat, np.ndarray],
shape: Optional[Sequence[int]] = None,
dtype: Type = np.float32,
seed: Optional[int] = None,
):
assert dtype is not None, "dtype must be explicitly provided. " assert dtype is not None, "dtype must be explicitly provided. "
self.dtype = np.dtype(dtype) self.dtype = np.dtype(dtype)
# determine shape if it isn't provided directly # determine shape if it isn't provided directly
if shape is not None: if shape is not None:
shape = tuple(shape) shape = tuple(shape)
assert (
np.isscalar(low) or low.shape == shape
), "low.shape doesn't match provided shape"
assert (
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match provided shape"
elif not np.isscalar(low): elif not np.isscalar(low):
shape = low.shape shape = low.shape # type: ignore
assert (
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match low.shape"
elif not np.isscalar(high): elif not np.isscalar(high):
shape = high.shape shape = high.shape # type: ignore
assert (
np.isscalar(low) or low.shape == shape
), "low.shape doesn't match high.shape"
else: else:
raise ValueError( raise ValueError(
"shape must be provided or inferred from the shapes of low or high" "shape must be provided or inferred from the shapes of low or high"
) )
assert isinstance(shape, tuple)
# handle infinite bounds and broadcast at the same time if needed low = _broadcast(low, dtype, shape, inf_sign="-") # type: ignore
if np.isscalar(low): high = _broadcast(high, dtype, shape, inf_sign="+")
low = get_inf(dtype, "-") if np.isinf(low) else low
low = np.full(shape, low, dtype=dtype)
else:
if np.any(np.isinf(low)):
# create new array with dtype, but maintain old one to preserve np.inf
temp_low = low.astype(dtype)
temp_low[np.isinf(low)] = get_inf(dtype, "-")
low = temp_low
if np.isscalar(high): assert isinstance(low, np.ndarray)
high = get_inf(dtype, "+") if np.isinf(high) else high assert low.shape == shape, "low.shape doesn't match provided shape"
high = np.full(shape, high, dtype=dtype) assert isinstance(high, np.ndarray)
else: assert high.shape == shape, "high.shape doesn't match provided shape"
if np.any(np.isinf(high)):
# create new array with dtype, but maintain old one to preserve np.inf
temp_high = high.astype(dtype)
temp_high[np.isinf(high)] = get_inf(dtype, "+")
high = temp_high
self._shape = shape self._shape: Tuple[int, ...] = shape
self.low = low
self.high = high
low_precision = get_precision(self.low.dtype) low_precision = get_precision(low.dtype)
high_precision = get_precision(self.high.dtype) high_precision = get_precision(high.dtype)
dtype_precision = get_precision(self.dtype) dtype_precision = get_precision(self.dtype)
if min(low_precision, high_precision) > dtype_precision: if min(low_precision, high_precision) > dtype_precision: # type: ignore
logger.warn(f"Box bound precision lowered by casting to {self.dtype}") logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
self.low = self.low.astype(self.dtype) self.low = low.astype(self.dtype)
self.high = self.high.astype(self.dtype) self.high = high.astype(self.dtype)
self.low_repr = _short_repr(self.low) self.low_repr = _short_repr(self.low)
self.high_repr = _short_repr(self.high) self.high_repr = _short_repr(self.high)
@@ -103,9 +88,14 @@ class Box(Space):
super().__init__(self.shape, self.dtype, seed) super().__init__(self.shape, self.dtype, seed)
def is_bounded(self, manner="both"): @property
below = np.all(self.bounded_below) def shape(self) -> Tuple[int, ...]:
above = np.all(self.bounded_above) """Has stricter type than gym.Space - never None."""
return self._shape
def is_bounded(self, manner: str = "both") -> bool:
below = bool(np.all(self.bounded_below))
above = bool(np.all(self.bounded_above))
if manner == "both": if manner == "both":
return below and above return below and above
elif manner == "below": elif manner == "below":
@@ -115,7 +105,7 @@ class Box(Space):
else: else:
raise ValueError("manner is not in {'below', 'above', 'both'}") raise ValueError("manner is not in {'below', 'above', 'both'}")
def sample(self): def sample(self) -> np.ndarray:
""" """
Generates a single random sample inside of the Box. Generates a single random sample inside of the Box.
@@ -158,12 +148,12 @@ class Box(Space):
return sample.astype(self.dtype) return sample.astype(self.dtype)
def contains(self, x): def contains(self, x) -> bool:
if not isinstance(x, np.ndarray): if not isinstance(x, np.ndarray):
logger.warn("Casting input x to numpy array.") logger.warn("Casting input x to numpy array.")
x = np.asarray(x, dtype=self.dtype) x = np.asarray(x, dtype=self.dtype)
return ( return bool(
np.can_cast(x.dtype, self.dtype) np.can_cast(x.dtype, self.dtype)
and x.shape == self.shape and x.shape == self.shape
and np.all(x >= self.low) and np.all(x >= self.low)
@@ -173,13 +163,13 @@ class Box(Space):
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n):
return np.array(sample_n).tolist() return np.array(sample_n).tolist()
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> list[np.ndarray]:
return [np.asarray(sample) for sample in sample_n] return [np.asarray(sample) for sample in sample_n]
def __repr__(self): def __repr__(self) -> str:
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})" return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
def __eq__(self, other): def __eq__(self, other) -> bool:
return ( return (
isinstance(other, Box) isinstance(other, Box)
and (self.shape == other.shape) and (self.shape == other.shape)
@@ -188,7 +178,7 @@ class Box(Space):
) )
def get_inf(dtype, sign): def get_inf(dtype, sign: str) -> SupportsFloat:
"""Returns an infinite that doesn't break things. """Returns an infinite that doesn't break things.
`dtype` must be an `np.dtype` `dtype` must be an `np.dtype`
`bound` must be either `min` or `max` `bound` must be either `min` or `max`
@@ -211,8 +201,28 @@ def get_inf(dtype, sign):
raise ValueError(f"Unknown dtype {dtype} for infinite bounds") raise ValueError(f"Unknown dtype {dtype} for infinite bounds")
def get_precision(dtype): def get_precision(dtype) -> SupportsFloat:
if np.issubdtype(dtype, np.floating): if np.issubdtype(dtype, np.floating):
return np.finfo(dtype).precision return np.finfo(dtype).precision
else: else:
return np.inf return np.inf
def _broadcast(
    value: Union[SupportsFloat, np.ndarray],
    dtype,
    shape: tuple[int, ...],
    inf_sign: str,
) -> np.ndarray:
    """Turn a bound into an array, swapping infinities for dtype-safe values.

    A scalar bound is expanded to ``shape`` with ``np.full``; if it is
    infinite it is first replaced by ``get_inf(dtype, inf_sign)``.

    A non-scalar bound must be an ``np.ndarray``. It is returned unchanged
    unless ``np.isinf`` flags at least one entry, in which case a copy cast to
    ``dtype`` is returned with every flagged entry overwritten by
    ``get_inf(dtype, inf_sign)``. Array bounds are not reshaped here.
    """
    if not np.isscalar(value):
        assert isinstance(value, np.ndarray)
        inf_mask = np.isinf(value)
        if not np.any(inf_mask):
            return value
        # Work on a cast copy so the caller's array keeps its np.inf entries.
        converted = value.astype(dtype)
        converted[inf_mask] = get_inf(dtype, inf_sign)
        return converted

    if np.isinf(value):  # type: ignore
        value = get_inf(dtype, inf_sign)
    return np.full(shape, value, dtype=dtype)

View File

@@ -1,10 +1,13 @@
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Dict as TypingDict
import numpy as np import numpy as np
from .space import Space from .space import Space
class Dict(Space, Mapping): class Dict(Space[TypingDict[str, Space]], Mapping):
""" """
A dictionary of simpler spaces. A dictionary of simpler spaces.
@@ -34,7 +37,12 @@ class Dict(Space, Mapping):
}) })
""" """
def __init__(self, spaces=None, seed=None, **spaces_kwargs): def __init__(
self,
spaces: dict[str, Space] | None = None,
seed: dict | int | None = None,
**spaces_kwargs: Space
):
assert (spaces is None) or ( assert (spaces is None) or (
not spaces_kwargs not spaces_kwargs
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)" ), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
@@ -57,10 +65,10 @@ class Dict(Space, Mapping):
space, Space space, Space
), "Values of the dict should be instances of gym.Space" ), "Values of the dict should be instances of gym.Space"
super().__init__( super().__init__(
None, None, seed None, None, seed # type: ignore
) # None for shape and dtype, since it'll require special handling ) # None for shape and dtype, since it'll require special handling
def seed(self, seed=None): def seed(self, seed: dict | int | None = None) -> list:
seeds = [] seeds = []
if isinstance(seed, dict): if isinstance(seed, dict):
for key, seed_key in zip(self.spaces, seed): for key, seed_key in zip(self.spaces, seed):
@@ -97,10 +105,10 @@ class Dict(Space, Mapping):
return seeds return seeds
def sample(self): def sample(self) -> dict:
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()]) return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
def contains(self, x): def contains(self, x) -> bool:
if not isinstance(x, dict) or len(x) != len(self.spaces): if not isinstance(x, dict) or len(x) != len(self.spaces):
return False return False
for k, space in self.spaces.items(): for k, space in self.spaces.items():
@@ -119,29 +127,30 @@ class Dict(Space, Mapping):
def __iter__(self): def __iter__(self):
yield from self.spaces yield from self.spaces
def __len__(self): def __len__(self) -> int:
return len(self.spaces) return len(self.spaces)
def __repr__(self): def __repr__(self) -> str:
return ( return (
"Dict(" "Dict("
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
+ ")" + ")"
) )
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n: list) -> dict:
# serialize as dict-repr of vectors # serialize as dict-repr of vectors
return { return {
key: space.to_jsonable([sample[key] for sample in sample_n]) key: space.to_jsonable([sample[key] for sample in sample_n])
for key, space in self.spaces.items() for key, space in self.spaces.items()
} }
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n: dict[str, list]) -> list:
dict_of_list = {} dict_of_list: dict[str, list] = {}
for key, space in self.spaces.items(): for key, space in self.spaces.items():
dict_of_list[key] = space.from_jsonable(sample_n[key]) dict_of_list[key] = space.from_jsonable(sample_n[key])
ret = [] ret = []
for i, _ in enumerate(dict_of_list[key]): n_elements = len(next(iter(dict_of_list.values())))
for i in range(n_elements):
entry = {} entry = {}
for key, value in dict_of_list.items(): for key, value in dict_of_list.items():
entry[key] = value[i] entry[key] = value[i]

View File

@@ -1,8 +1,10 @@
from typing import Optional
import numpy as np import numpy as np
from .space import Space from .space import Space
class Discrete(Space): class Discrete(Space[int]):
r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`. r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`.
A start value can be optionally specified to shift the range A start value can be optionally specified to shift the range
@@ -15,33 +17,33 @@ class Discrete(Space):
""" """
def __init__(self, n, seed=None, start=0): def __init__(self, n: int, seed: Optional[int] = None, start: int = 0):
assert n > 0, "n (counts) have to be positive" assert n > 0, "n (counts) have to be positive"
assert isinstance(start, (int, np.integer)) assert isinstance(start, (int, np.integer))
self.n = int(n) self.n = int(n)
self.start = int(start) self.start = int(start)
super().__init__((), np.int64, seed) super().__init__((), np.int64, seed)
def sample(self): def sample(self) -> int:
return self.start + self.np_random.randint(self.n) return self.start + self.np_random.randint(self.n)
def contains(self, x): def contains(self, x) -> bool:
if isinstance(x, int): if isinstance(x, int):
as_int = x as_int = x
elif isinstance(x, (np.generic, np.ndarray)) and ( elif isinstance(x, (np.generic, np.ndarray)) and (
x.dtype.char in np.typecodes["AllInteger"] and x.shape == () x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
): ):
as_int = int(x) as_int = int(x) # type: ignore
else: else:
return False return False
return self.start <= as_int < self.start + self.n return self.start <= as_int < self.start + self.n
def __repr__(self): def __repr__(self) -> str:
if self.start != 0: if self.start != 0:
return "Discrete(%d, start=%d)" % (self.n, self.start) return "Discrete(%d, start=%d)" % (self.n, self.start)
return "Discrete(%d)" % self.n return "Discrete(%d)" % self.n
def __eq__(self, other): def __eq__(self, other) -> bool:
return ( return (
isinstance(other, Discrete) isinstance(other, Discrete)
and self.n == other.n and self.n == other.n

View File

@@ -1,9 +1,11 @@
from collections.abc import Sequence from __future__ import annotations
from typing import Optional, Union, Sequence
import numpy as np import numpy as np
from .space import Space from .space import Space
class MultiBinary(Space): class MultiBinary(Space[np.ndarray]):
""" """
An n-shape binary space. An n-shape binary space.
@@ -27,7 +29,9 @@ class MultiBinary(Space):
""" """
def __init__(self, n, seed=None): def __init__(
self, n: Union[np.ndarray, Sequence[int], int], seed: Optional[int] = None
):
if isinstance(n, (Sequence, np.ndarray)): if isinstance(n, (Sequence, np.ndarray)):
self.n = input_n = tuple(int(i) for i in n) self.n = input_n = tuple(int(i) for i in n)
else: else:
@@ -38,24 +42,29 @@ class MultiBinary(Space):
super().__init__(input_n, np.int8, seed) super().__init__(input_n, np.int8, seed)
def sample(self): @property
def shape(self) -> tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape # type: ignore
def sample(self) -> np.ndarray:
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype) return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
def contains(self, x): def contains(self, x) -> bool:
if isinstance(x, Sequence): if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check x = np.array(x) # Promote list to array for contains check
if self.shape != x.shape: if self.shape != x.shape:
return False return False
return ((x == 0) | (x == 1)).all() return ((x == 0) | (x == 1)).all()
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n) -> list:
return np.array(sample_n).tolist() return np.array(sample_n).tolist()
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n) -> list:
return [np.asarray(sample) for sample in sample_n] return [np.asarray(sample) for sample in sample_n]
def __repr__(self): def __repr__(self) -> str:
return f"MultiBinary({self.n})" return f"MultiBinary({self.n})"
def __eq__(self, other): def __eq__(self, other) -> bool:
return isinstance(other, MultiBinary) and self.n == other.n return isinstance(other, MultiBinary) and self.n == other.n

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
import numpy as np import numpy as np
from gym import logger from gym import logger
@@ -5,7 +7,7 @@ from .space import Space
from .discrete import Discrete from .discrete import Discrete
class MultiDiscrete(Space): class MultiDiscrete(Space[np.ndarray]):
""" """
- The multi-discrete action space consists of a series of discrete action spaces with different number of actions in each - The multi-discrete action space consists of a series of discrete action spaces with different number of actions in each
- It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space
@@ -26,7 +28,7 @@ class MultiDiscrete(Space):
""" """
def __init__(self, nvec, dtype=np.int64, seed=None): def __init__(self, nvec: list[int], dtype=np.int64, seed=None):
""" """
nvec: vector of counts of each categorical variable nvec: vector of counts of each categorical variable
""" """
@@ -35,15 +37,20 @@ class MultiDiscrete(Space):
super().__init__(self.nvec.shape, dtype, seed) super().__init__(self.nvec.shape, dtype, seed)
def sample(self): @property
def shape(self) -> tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape # type: ignore
def sample(self) -> np.ndarray:
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype) return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
def contains(self, x): def contains(self, x) -> bool:
if isinstance(x, Sequence): if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check x = np.array(x) # Promote list to array for contains check
# if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x # if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
# is within correct bounds for space dtype (even though x does not have to be unsigned) # is within correct bounds for space dtype (even though x does not have to be unsigned)
return x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all() return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all())
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n):
return [sample.tolist() for sample in sample_n] return [sample.tolist() for sample in sample_n]

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
from typing import ( from typing import (
TypeVar, TypeVar,
Generic, Generic,
Optional, Optional,
Sequence, Sequence,
Union,
Iterable, Iterable,
Mapping, Mapping,
Tuple, Type,
) )
import numpy as np import numpy as np
@@ -32,8 +33,13 @@ class Space(Generic[T_cov]):
not handle custom spaces properly. Use custom spaces with care. not handle custom spaces properly. Use custom spaces with care.
""" """
def __init__(self, shape: Optional[Sequence[int]] = None, dtype=None, seed=None): def __init__(
import numpy as np # takes about 300-400ms to import, so we load lazily self,
shape: Optional[Sequence[int]] = None,
dtype: Optional[Type | str] = None,
seed: Optional[int] = None,
):
import numpy as np # noqa ## takes about 300-400ms to import, so we load lazily
self._shape = None if shape is None else tuple(shape) self._shape = None if shape is None else tuple(shape)
self.dtype = None if dtype is None else np.dtype(dtype) self.dtype = None if dtype is None else np.dtype(dtype)
@@ -42,7 +48,7 @@ class Space(Generic[T_cov]):
self.seed(seed) self.seed(seed)
@property @property
def np_random(self) -> np.random.RandomState: def np_random(self) -> seeding.RandomNumberGenerator:
"""Lazily seed the rng since this is expensive and only needed if """Lazily seed the rng since this is expensive and only needed if
sampling from this space. sampling from this space.
""" """
@@ -52,7 +58,7 @@ class Space(Generic[T_cov]):
return self._np_random # type: ignore ## self.seed() call guarantees right type. return self._np_random # type: ignore ## self.seed() call guarantees right type.
@property @property
def shape(self) -> Optional[Tuple[int, ...]]: def shape(self) -> Optional[tuple[int, ...]]:
"""Return the shape of the space as an immutable property""" """Return the shape of the space as an immutable property"""
return self._shape return self._shape
@@ -61,7 +67,7 @@ class Space(Generic[T_cov]):
uniform or non-uniform sampling based on boundedness of space.""" uniform or non-uniform sampling based on boundedness of space."""
raise NotImplementedError raise NotImplementedError
def seed(self, seed: Optional[int] = None): def seed(self, seed: Optional[int] = None) -> list:
"""Seed the PRNG of this space.""" """Seed the PRNG of this space."""
self._np_random, seed = seeding.np_random(seed) self._np_random, seed = seeding.np_random(seed)
return [seed] return [seed]
@@ -76,7 +82,7 @@ class Space(Generic[T_cov]):
def __contains__(self, x) -> bool: def __contains__(self, x) -> bool:
return self.contains(x) return self.contains(x)
def __setstate__(self, state: Union[Iterable, Mapping]): def __setstate__(self, state: Iterable | Mapping):
# Don't mutate the original state # Don't mutate the original state
state = dict(state) state = dict(state)
@@ -95,12 +101,12 @@ class Space(Generic[T_cov]):
# Update our state # Update our state
self.__dict__.update(state) self.__dict__.update(state)
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n: Sequence[T_cov]) -> list:
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
# By default, assume identity is JSONable # By default, assume identity is JSONable
return sample_n return list(sample_n)
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n: list) -> list[T_cov]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
# By default, assume identity is JSONable # By default, assume identity is JSONable
return sample_n return sample_n

View File

@@ -1,8 +1,9 @@
from typing import Iterable, List, Optional, Union
import numpy as np import numpy as np
from .space import Space from .space import Space
class Tuple(Space): class Tuple(Space[tuple]):
""" """
A tuple (i.e., product) of simpler spaces A tuple (i.e., product) of simpler spaces
@@ -10,16 +11,18 @@ class Tuple(Space):
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3))) self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
""" """
def __init__(self, spaces, seed=None): def __init__(
self, spaces: Iterable[Space], seed: Optional[Union[int, List[int]]] = None
):
spaces = tuple(spaces) spaces = tuple(spaces)
self.spaces = spaces self.spaces = spaces
for space in spaces: for space in spaces:
assert isinstance( assert isinstance(
space, Space space, Space
), "Elements of the tuple must be instances of gym.Space" ), "Elements of the tuple must be instances of gym.Space"
super().__init__(None, None, seed) super().__init__(None, None, seed) # type: ignore
def seed(self, seed=None): def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
seeds = [] seeds = []
if isinstance(seed, list): if isinstance(seed, list):
@@ -50,10 +53,10 @@ class Tuple(Space):
return seeds return seeds
def sample(self): def sample(self) -> tuple:
return tuple(space.sample() for space in self.spaces) return tuple(space.sample() for space in self.spaces)
def contains(self, x): def contains(self, x) -> bool:
if isinstance(x, (list, np.ndarray)): if isinstance(x, (list, np.ndarray)):
x = tuple(x) # Promote list and ndarray to tuple for contains check x = tuple(x) # Promote list and ndarray to tuple for contains check
return ( return (
@@ -62,17 +65,17 @@ class Tuple(Space):
and all(space.contains(part) for (space, part) in zip(self.spaces, x)) and all(space.contains(part) for (space, part) in zip(self.spaces, x))
) )
def __repr__(self): def __repr__(self) -> str:
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")" return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n) -> list:
# serialize as list-repr of tuple of vectors # serialize as list-repr of tuple of vectors
return [ return [
space.to_jsonable([sample[i] for sample in sample_n]) space.to_jsonable([sample[i] for sample in sample_n])
for i, space in enumerate(self.spaces) for i, space in enumerate(self.spaces)
] ]
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n) -> list:
return [ return [
sample sample
for sample in zip( for sample in zip(
@@ -83,11 +86,11 @@ class Tuple(Space):
) )
] ]
def __getitem__(self, index): def __getitem__(self, index: int) -> Space:
return self.spaces[index] return self.spaces[index]
def __len__(self): def __len__(self) -> int:
return len(self.spaces) return len(self.spaces)
def __eq__(self, other): def __eq__(self, other) -> bool:
return isinstance(other, Tuple) and self.spaces == other.spaces return isinstance(other, Tuple) and self.spaces == other.spaces

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from functools import singledispatch, reduce from functools import singledispatch, reduce
from typing import TypeVar, Union
import numpy as np import numpy as np
import operator as op import operator as op
@@ -9,10 +12,11 @@ from gym.spaces import MultiDiscrete
from gym.spaces import MultiBinary from gym.spaces import MultiBinary
from gym.spaces import Tuple from gym.spaces import Tuple
from gym.spaces import Dict from gym.spaces import Dict
from gym.spaces import Space
@singledispatch @singledispatch
def flatdim(space): def flatdim(space: Space) -> int:
"""Return the number of dimensions a flattened equivalent of this space """Return the number of dimensions a flattened equivalent of this space
would have. would have.
@@ -24,32 +28,35 @@ def flatdim(space):
@flatdim.register(Box) @flatdim.register(Box)
@flatdim.register(MultiBinary) @flatdim.register(MultiBinary)
def _flatdim_box_multibinary(space): def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int:
return reduce(op.mul, space.shape, 1) return reduce(op.mul, space.shape, 1)
@flatdim.register(Discrete) @flatdim.register(Discrete)
def _flatdim_discrete(space): def _flatdim_discrete(space: Discrete) -> int:
return int(space.n) return int(space.n)
@flatdim.register(MultiDiscrete) @flatdim.register(MultiDiscrete)
def _flatdim_multidiscrete(space): def _flatdim_multidiscrete(space: MultiDiscrete) -> int:
return int(np.sum(space.nvec)) return int(np.sum(space.nvec))
@flatdim.register(Tuple) @flatdim.register(Tuple)
def _flatdim_tuple(space): def _flatdim_tuple(space: Tuple) -> int:
return sum(flatdim(s) for s in space.spaces) return sum(flatdim(s) for s in space.spaces)
@flatdim.register(Dict) @flatdim.register(Dict)
def _flatdim_dict(space): def _flatdim_dict(space: Dict) -> int:
return sum(flatdim(s) for s in space.spaces.values()) return sum(flatdim(s) for s in space.spaces.values())
T = TypeVar("T")
@singledispatch @singledispatch
def flatten(space, x): def flatten(space: Space[T], x: T) -> np.ndarray:
"""Flatten a data point from a space. """Flatten a data point from a space.
This is useful when e.g. points from spaces must be passed to a neural This is useful when e.g. points from spaces must be passed to a neural
@@ -64,19 +71,19 @@ def flatten(space, x):
@flatten.register(Box) @flatten.register(Box)
@flatten.register(MultiBinary) @flatten.register(MultiBinary)
def _flatten_box_multibinary(space, x): def _flatten_box_multibinary(space, x) -> np.ndarray:
return np.asarray(x, dtype=space.dtype).flatten() return np.asarray(x, dtype=space.dtype).flatten()
@flatten.register(Discrete) @flatten.register(Discrete)
def _flatten_discrete(space, x): def _flatten_discrete(space, x) -> np.ndarray:
onehot = np.zeros(space.n, dtype=space.dtype) onehot = np.zeros(space.n, dtype=space.dtype)
onehot[x] = 1 onehot[x] = 1
return onehot return onehot
@flatten.register(MultiDiscrete) @flatten.register(MultiDiscrete)
def _flatten_multidiscrete(space, x): def _flatten_multidiscrete(space, x) -> np.ndarray:
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype) offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
offsets[1:] = np.cumsum(space.nvec.flatten()) offsets[1:] = np.cumsum(space.nvec.flatten())
@@ -86,17 +93,17 @@ def _flatten_multidiscrete(space, x):
@flatten.register(Tuple) @flatten.register(Tuple)
def _flatten_tuple(space, x): def _flatten_tuple(space, x) -> np.ndarray:
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)]) return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
@flatten.register(Dict) @flatten.register(Dict)
def _flatten_dict(space, x): def _flatten_dict(space, x) -> np.ndarray:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()]) return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
@singledispatch @singledispatch
def unflatten(space, x): def unflatten(space: Space[T], x: np.ndarray) -> T:
"""Unflatten a data point from a space. """Unflatten a data point from a space.
This reverses the transformation applied by ``flatten()``. You must ensure This reverses the transformation applied by ``flatten()``. You must ensure
@@ -111,17 +118,17 @@ def unflatten(space, x):
@unflatten.register(Box) @unflatten.register(Box)
@unflatten.register(MultiBinary) @unflatten.register(MultiBinary)
def _unflatten_box_multibinary(space, x): def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.ndarray:
return np.asarray(x, dtype=space.dtype).reshape(space.shape) return np.asarray(x, dtype=space.dtype).reshape(space.shape)
@unflatten.register(Discrete) @unflatten.register(Discrete)
def _unflatten_discrete(space, x): def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
return int(np.nonzero(x)[0][0]) return int(np.nonzero(x)[0][0])
@unflatten.register(MultiDiscrete) @unflatten.register(MultiDiscrete)
def _unflatten_multidiscrete(space, x): def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype) offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
offsets[1:] = np.cumsum(space.nvec.flatten()) offsets[1:] = np.cumsum(space.nvec.flatten())
@@ -130,7 +137,7 @@ def _unflatten_multidiscrete(space, x):
@unflatten.register(Tuple) @unflatten.register(Tuple)
def _unflatten_tuple(space, x): def _unflatten_tuple(space: Tuple, x: np.ndarray) -> tuple:
dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_) dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1])) list_flattened = np.split(x, np.cumsum(dims[:-1]))
return tuple( return tuple(
@@ -139,7 +146,7 @@ def _unflatten_tuple(space, x):
@unflatten.register(Dict) @unflatten.register(Dict)
def _unflatten_dict(space, x): def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_) dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1])) list_flattened = np.split(x, np.cumsum(dims[:-1]))
return OrderedDict( return OrderedDict(
@@ -151,7 +158,7 @@ def _unflatten_dict(space, x):
@singledispatch @singledispatch
def flatten_space(space): def flatten_space(space: Space) -> Box:
"""Flatten a space into a single ``Box``. """Flatten a space into a single ``Box``.
This is equivalent to ``flatten()``, but operates on the space itself. The This is equivalent to ``flatten()``, but operates on the space itself. The
@@ -193,32 +200,32 @@ def flatten_space(space):
@flatten_space.register(Box) @flatten_space.register(Box)
def _flatten_space_box(space): def _flatten_space_box(space: Box) -> Box:
return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype) return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype)
@flatten_space.register(Discrete) @flatten_space.register(Discrete)
@flatten_space.register(MultiBinary) @flatten_space.register(MultiBinary)
@flatten_space.register(MultiDiscrete) @flatten_space.register(MultiDiscrete)
def _flatten_space_binary(space): def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box:
return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype) return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype)
@flatten_space.register(Tuple) @flatten_space.register(Tuple)
def _flatten_space_tuple(space): def _flatten_space_tuple(space: Tuple) -> Box:
space = [flatten_space(s) for s in space.spaces] space_list = [flatten_space(s) for s in space.spaces]
return Box( return Box(
low=np.concatenate([s.low for s in space]), low=np.concatenate([s.low for s in space_list]),
high=np.concatenate([s.high for s in space]), high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space]), dtype=np.result_type(*[s.dtype for s in space_list]),
) )
@flatten_space.register(Dict) @flatten_space.register(Dict)
def _flatten_space_dict(space): def _flatten_space_dict(space: Dict) -> Box:
space = [flatten_space(s) for s in space.spaces.values()] space_list = [flatten_space(s) for s in space.spaces.values()]
return Box( return Box(
low=np.concatenate([s.low for s in space]), low=np.concatenate([s.low for s in space_list]),
high=np.concatenate([s.high for s in space]), high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space]), dtype=np.result_type(*[s.dtype for s in space_list]),
) )

View File

@@ -1,5 +1,5 @@
import hashlib import hashlib
from typing import Optional, List, Union from typing import Optional, List, Tuple, Union, Any
import os import os
import struct import struct
@@ -10,7 +10,7 @@ from gym import error
from gym.logger import deprecation from gym.logger import deprecation
def np_random(seed: Optional[int] = None) -> tuple: def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
if seed is not None and not (isinstance(seed, int) and 0 <= seed): if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}") raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")

View File

@@ -4,7 +4,7 @@ include = [
"gym/version.py", "gym/version.py",
"gym/logger.py", "gym/logger.py",
"gym/envs/registration.py", "gym/envs/registration.py",
"gym/spaces/space.py", "gym/spaces/**.py",
"gym/core.py", "gym/core.py",
"gym/utils/seeding.py" "gym/utils/seeding.py"
] ]