# Scraped from a repository file view: Gymnasium/gym/spaces/dict.py
# (163 lines, 5.5 KiB, Python; "Raw / Normal View / History" page chrome removed)

from __future__ import annotations
# blame timestamp: 2017-09-05 08:49:43 -07:00
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from typing import Dict as TypingDict
import numpy as np
from .space import Space
# blame timestamp: 2017-09-05 08:49:43 -07:00
class Dict(Space[TypingDict[str, Space]], Mapping):
    """
    A dictionary of simpler spaces.

    Subspaces are stored in ``self.spaces`` as an ``OrderedDict``. A plain
    ``dict`` passed to the constructor is sorted by key (falling back to
    insertion order when the keys cannot be compared with each other); an
    ``OrderedDict`` or a sequence of ``(key, space)`` pairs keeps its order.

    Example usage::

        self.observation_space = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})

    Example usage [nested]::

        self.nested_observation_space = spaces.Dict({
            'sensors': spaces.Dict({
                'position': spaces.Box(low=-100, high=100, shape=(3,)),
                'velocity': spaces.Box(low=-1, high=1, shape=(3,)),
                'front_cam': spaces.Tuple((
                    spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                    spaces.Box(low=0, high=1, shape=(10, 10, 3))
                )),
                'rear_cam': spaces.Box(low=0, high=1, shape=(10, 10, 3)),
            }),
            'ext_controller': spaces.MultiDiscrete((5, 2, 2)),
            'inner_state':spaces.Dict({
                'charge': spaces.Discrete(100),
                'system_checks': spaces.MultiBinary(10),
                'job_status': spaces.Dict({
                    'task': spaces.Discrete(5),
                    'progress': spaces.Box(low=0, high=100, shape=()),
                })
            })
        })
    """

    def __init__(
        self,
        spaces: dict[str, Space] | None = None,
        seed: dict | int | None = None,
        **spaces_kwargs: Space,
    ):
        """Build the space from a mapping, a sequence of pairs, or keyword arguments.

        Args:
            spaces: A ``dict``/``OrderedDict`` of subspaces, or a sequence of
                ``(key, space)`` pairs. Mutually exclusive with ``spaces_kwargs``.
            seed: Forwarded to the base ``Space`` constructor.
            **spaces_kwargs: Subspaces given as ``Dict(foo=x, bar=z)``; used
                only when ``spaces`` is ``None``.

        Raises:
            AssertionError: If both ``spaces`` and keyword subspaces are given,
                if ``spaces`` cannot be turned into an ``OrderedDict``, or if
                any value is not a ``Space`` instance.
        """
        assert (spaces is None) or (
            not spaces_kwargs
        ), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"

        if spaces is None:
            spaces = spaces_kwargs
        if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
            try:
                # Sort by key so the subspace order is deterministic.
                spaces = OrderedDict(sorted(spaces.items()))
            except TypeError:  # raise when sort by different types of keys
                spaces = OrderedDict(spaces.items())
        if isinstance(spaces, Sequence):
            spaces = OrderedDict(spaces)

        assert isinstance(spaces, OrderedDict), "spaces must be a dictionary"

        # NOTE(review): assigned before super().__init__ on purpose -- the base
        # constructor presumably calls self.seed(seed), which reads self.spaces.
        self.spaces = spaces
        for space in spaces.values():
            assert isinstance(
                space, Space
            ), "Values of the dict should be instances of gym.Space"
        super().__init__(
            None, None, seed  # type: ignore
        )  # None for shape and dtype, since it'll require special handling

    def seed(self, seed: dict | int | None = None) -> list:
        """Seed the subspaces and return the list of seeds that were used.

        Args:
            seed: One of

                * ``dict`` -- maps subspace keys to their seeds. Its keys must
                  match the keys of ``self.spaces`` in iteration order.
                * ``int`` -- seeds this space via ``super().seed`` and then
                  draws one subseed per subspace from ``self.np_random``.
                * ``None`` -- every subspace is seeded with ``None``.

        Returns:
            For ``dict``/``None``: the concatenation of the lists returned by
            each subspace's ``seed``. For ``int``: the result of
            ``super().seed`` followed by the first seed of each subspace.

        Raises:
            AssertionError: If a key of a seed dict does not match the
                corresponding key of ``self.spaces``.
            TypeError: If ``seed`` is not a dict, an int, or ``None``.
        """
        seeds = []

        if isinstance(seed, dict):
            for key, seed_key in zip(self.spaces, seed):
                # The message is a real string (not print(...), which returns
                # None and would leave the AssertionError without a message).
                assert key == seed_key, (
                    f"Key value {seed_key} in passed seed dict did not match "
                    f"key value {key} in spaces Dict."
                )
                seeds += self.spaces[key].seed(seed[seed_key])
        elif isinstance(seed, int):
            seeds = super().seed(seed)
            max_subseed = np.iinfo(int).max
            n_subspaces = len(self.spaces)
            try:
                subseeds = self.np_random.choice(
                    max_subseed,
                    size=n_subspaces,
                    replace=False,  # unique subseed for each subspace
                )
            except ValueError:
                subseeds = self.np_random.choice(
                    max_subseed,
                    size=n_subspaces,
                    replace=True,  # we get more than INT_MAX subspaces
                )

            for subspace, subseed in zip(self.spaces.values(), subseeds):
                seeds.append(subspace.seed(int(subseed))[0])
        elif seed is None:
            for space in self.spaces.values():
                seeds += space.seed(seed)
        else:
            raise TypeError("Passed seed not of an expected type: dict or int or None")

        return seeds

    def sample(self) -> dict:
        """Return an ``OrderedDict`` with one sample drawn from each subspace."""
        return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])

    def contains(self, x) -> bool:
        """Return whether ``x`` is a dict with exactly this space's keys and valid values."""
        if not isinstance(x, dict) or len(x) != len(self.spaces):
            return False
        # Equal length plus "every key of ours is in x" means the key sets match.
        for k, space in self.spaces.items():
            if k not in x:
                return False
            if not space.contains(x[k]):
                return False
        return True

    def __getitem__(self, key):
        """Return the subspace stored under ``key``."""
        return self.spaces[key]

    def __setitem__(self, key, value):
        """Store ``value`` as the subspace for ``key`` (no type check is performed)."""
        self.spaces[key] = value

    def __iter__(self):
        """Iterate over the subspace keys."""
        yield from self.spaces

    def __len__(self) -> int:
        """Return the number of subspaces."""
        return len(self.spaces)

    def __repr__(self) -> str:
        """Return ``Dict(key:space, ...)`` built from ``str`` of each key and subspace."""
        return (
            "Dict("
            + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
            + ")"
        )

    def to_jsonable(self, sample_n: list) -> dict:
        """Convert a list of samples into a dict of per-key jsonable batches.

        Each subspace serializes the list of values found under its key.
        """
        # serialize as dict-repr of vectors
        return {
            key: space.to_jsonable([sample[key] for sample in sample_n])
            for key, space in self.spaces.items()
        }

    def from_jsonable(self, sample_n: dict[str, list]) -> list:
        """Convert a dict of per-key jsonable batches back into a list of samples.

        The number of samples is taken from the first subspace's decoded
        batch. A space without subspaces yields an empty list instead of
        raising ``StopIteration``.
        """
        dict_of_list: dict[str, list] = {}
        for key, space in self.spaces.items():
            dict_of_list[key] = space.from_jsonable(sample_n[key])

        # The default guards the empty-space case, where there is no first batch.
        n_elements = len(next(iter(dict_of_list.values()), ()))

        ret = []
        for i in range(n_elements):
            entry = {}
            for key, value in dict_of_list.items():
                entry[key] = value[i]
            ret.append(entry)
        return ret