mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-19 13:32:03 +00:00
typing in gym.spaces (#2541)
* typing in spaces.Box and spaces.Discrete
* adds typing to dict and tuple spaces
* Typecheck all spaces
* Explicit regex to include all files under space folder
* Style: use native types and __future__ annotations
* Allow only specific strings for Box.is_bounded args
* Add typing to changes from #2517
* Remove Literal as it's not supported by py3.7
* Use more recent version of pyright
* Avoid name clash for type checker
* Revert "Avoid name clash for type checker"
  This reverts commit 1aaf3e0e0328171623a17a997b65fe734bc0afb1.
* Ignore the error. It's reported as probable bug at https://github.com/microsoft/pyright/issues/2852
* rebase and add typing for `_short_repr`
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from functools import singledispatch, reduce
|
||||
from typing import TypeVar, Union
|
||||
import numpy as np
|
||||
import operator as op
|
||||
|
||||
@@ -9,10 +12,11 @@ from gym.spaces import MultiDiscrete
|
||||
from gym.spaces import MultiBinary
|
||||
from gym.spaces import Tuple
|
||||
from gym.spaces import Dict
|
||||
from gym.spaces import Space
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatdim(space):
|
||||
def flatdim(space: Space) -> int:
|
||||
"""Return the number of dimensions a flattened equivalent of this space
|
||||
would have.
|
||||
|
||||
@@ -24,32 +28,35 @@ def flatdim(space):
|
||||
|
||||
@flatdim.register(Box)
|
||||
@flatdim.register(MultiBinary)
|
||||
def _flatdim_box_multibinary(space):
|
||||
def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int:
|
||||
return reduce(op.mul, space.shape, 1)
|
||||
|
||||
|
||||
@flatdim.register(Discrete)
|
||||
def _flatdim_discrete(space):
|
||||
def _flatdim_discrete(space: Discrete) -> int:
|
||||
return int(space.n)
|
||||
|
||||
|
||||
@flatdim.register(MultiDiscrete)
|
||||
def _flatdim_multidiscrete(space):
|
||||
def _flatdim_multidiscrete(space: MultiDiscrete) -> int:
|
||||
return int(np.sum(space.nvec))
|
||||
|
||||
|
||||
@flatdim.register(Tuple)
|
||||
def _flatdim_tuple(space):
|
||||
def _flatdim_tuple(space: Tuple) -> int:
|
||||
return sum(flatdim(s) for s in space.spaces)
|
||||
|
||||
|
||||
@flatdim.register(Dict)
|
||||
def _flatdim_dict(space):
|
||||
def _flatdim_dict(space: Dict) -> int:
|
||||
return sum(flatdim(s) for s in space.spaces.values())
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten(space, x):
|
||||
def flatten(space: Space[T], x: T) -> np.ndarray:
|
||||
"""Flatten a data point from a space.
|
||||
|
||||
This is useful when e.g. points from spaces must be passed to a neural
|
||||
@@ -64,19 +71,19 @@ def flatten(space, x):
|
||||
|
||||
@flatten.register(Box)
|
||||
@flatten.register(MultiBinary)
|
||||
def _flatten_box_multibinary(space, x):
|
||||
def _flatten_box_multibinary(space, x) -> np.ndarray:
|
||||
return np.asarray(x, dtype=space.dtype).flatten()
|
||||
|
||||
|
||||
@flatten.register(Discrete)
|
||||
def _flatten_discrete(space, x):
|
||||
def _flatten_discrete(space, x) -> np.ndarray:
|
||||
onehot = np.zeros(space.n, dtype=space.dtype)
|
||||
onehot[x] = 1
|
||||
return onehot
|
||||
|
||||
|
||||
@flatten.register(MultiDiscrete)
|
||||
def _flatten_multidiscrete(space, x):
|
||||
def _flatten_multidiscrete(space, x) -> np.ndarray:
|
||||
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||
|
||||
@@ -86,17 +93,17 @@ def _flatten_multidiscrete(space, x):
|
||||
|
||||
|
||||
@flatten.register(Tuple)
|
||||
def _flatten_tuple(space, x):
|
||||
def _flatten_tuple(space, x) -> np.ndarray:
|
||||
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
|
||||
|
||||
|
||||
@flatten.register(Dict)
|
||||
def _flatten_dict(space, x):
|
||||
def _flatten_dict(space, x) -> np.ndarray:
|
||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||
|
||||
|
||||
@singledispatch
|
||||
def unflatten(space, x):
|
||||
def unflatten(space: Space[T], x: np.ndarray) -> T:
|
||||
"""Unflatten a data point from a space.
|
||||
|
||||
This reverses the transformation applied by ``flatten()``. You must ensure
|
||||
@@ -111,17 +118,17 @@ def unflatten(space, x):
|
||||
|
||||
@unflatten.register(Box)
|
||||
@unflatten.register(MultiBinary)
|
||||
def _unflatten_box_multibinary(space, x):
|
||||
def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.ndarray:
|
||||
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
||||
|
||||
|
||||
@unflatten.register(Discrete)
|
||||
def _unflatten_discrete(space, x):
|
||||
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
|
||||
return int(np.nonzero(x)[0][0])
|
||||
|
||||
|
||||
@unflatten.register(MultiDiscrete)
|
||||
def _unflatten_multidiscrete(space, x):
|
||||
def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
|
||||
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||
|
||||
@@ -130,7 +137,7 @@ def _unflatten_multidiscrete(space, x):
|
||||
|
||||
|
||||
@unflatten.register(Tuple)
|
||||
def _unflatten_tuple(space, x):
|
||||
def _unflatten_tuple(space: Tuple, x: np.ndarray) -> tuple:
|
||||
dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_)
|
||||
list_flattened = np.split(x, np.cumsum(dims[:-1]))
|
||||
return tuple(
|
||||
@@ -139,7 +146,7 @@ def _unflatten_tuple(space, x):
|
||||
|
||||
|
||||
@unflatten.register(Dict)
|
||||
def _unflatten_dict(space, x):
|
||||
def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
|
||||
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
|
||||
list_flattened = np.split(x, np.cumsum(dims[:-1]))
|
||||
return OrderedDict(
|
||||
@@ -151,7 +158,7 @@ def _unflatten_dict(space, x):
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten_space(space):
|
||||
def flatten_space(space: Space) -> Box:
|
||||
"""Flatten a space into a single ``Box``.
|
||||
|
||||
This is equivalent to ``flatten()``, but operates on the space itself. The
|
||||
@@ -193,32 +200,32 @@ def flatten_space(space):
|
||||
|
||||
|
||||
@flatten_space.register(Box)
|
||||
def _flatten_space_box(space):
|
||||
def _flatten_space_box(space: Box) -> Box:
|
||||
return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype)
|
||||
|
||||
|
||||
@flatten_space.register(Discrete)
|
||||
@flatten_space.register(MultiBinary)
|
||||
@flatten_space.register(MultiDiscrete)
|
||||
def _flatten_space_binary(space):
|
||||
def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box:
|
||||
return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype)
|
||||
|
||||
|
||||
@flatten_space.register(Tuple)
|
||||
def _flatten_space_tuple(space):
|
||||
space = [flatten_space(s) for s in space.spaces]
|
||||
def _flatten_space_tuple(space: Tuple) -> Box:
|
||||
space_list = [flatten_space(s) for s in space.spaces]
|
||||
return Box(
|
||||
low=np.concatenate([s.low for s in space]),
|
||||
high=np.concatenate([s.high for s in space]),
|
||||
dtype=np.result_type(*[s.dtype for s in space]),
|
||||
low=np.concatenate([s.low for s in space_list]),
|
||||
high=np.concatenate([s.high for s in space_list]),
|
||||
dtype=np.result_type(*[s.dtype for s in space_list]),
|
||||
)
|
||||
|
||||
|
||||
@flatten_space.register(Dict)
|
||||
def _flatten_space_dict(space):
|
||||
space = [flatten_space(s) for s in space.spaces.values()]
|
||||
def _flatten_space_dict(space: Dict) -> Box:
|
||||
space_list = [flatten_space(s) for s in space.spaces.values()]
|
||||
return Box(
|
||||
low=np.concatenate([s.low for s in space]),
|
||||
high=np.concatenate([s.high for s in space]),
|
||||
dtype=np.result_type(*[s.dtype for s in space]),
|
||||
low=np.concatenate([s.low for s in space_list]),
|
||||
high=np.concatenate([s.high for s in space_list]),
|
||||
dtype=np.result_type(*[s.dtype for s in space_list]),
|
||||
)
|
||||
|
Reference in New Issue
Block a user