Remove ordereddict in favour of python dict (#977)

This commit is contained in:
Mark Towers
2024-03-22 11:19:41 +00:00
committed by GitHub
parent a79e5d6e8a
commit 15d179087e
8 changed files with 54 additions and 81 deletions

View File

@@ -1,4 +1,3 @@
import collections
import os import os
import time import time
from typing import Dict, Optional from typing import Dict, Optional
@@ -27,13 +26,11 @@ def _import_osmesa(width, height):
return GLContext(width, height) return GLContext(width, height)
_ALL_RENDERERS = collections.OrderedDict( _ALL_RENDERERS = {
[ "glfw": _import_glfw,
("glfw", _import_glfw), "egl": _import_egl,
("egl", _import_egl), "osmesa": _import_osmesa,
("osmesa", _import_osmesa), }
]
)
class BaseRender: class BaseRender:

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import collections.abc import collections.abc
import typing import typing
from collections import OrderedDict
from typing import Any, KeysView, Sequence from typing import Any, KeysView, Sequence
import numpy as np import numpy as np
@@ -20,7 +19,7 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
>>> from gymnasium.spaces import Dict, Box, Discrete >>> from gymnasium.spaces import Dict, Box, Discrete
>>> observation_space = Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)}, seed=42) >>> observation_space = Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)}, seed=42)
>>> observation_space.sample() >>> observation_space.sample()
OrderedDict([('color', 0), ('position', array([-0.3991573 , 0.21649833], dtype=float32))]) {'color': 0, 'position': array([-0.3991573 , 0.21649833], dtype=float32)}
With a nested dict: With a nested dict:
@@ -67,23 +66,23 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above. **spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
""" """
# Convert the spaces into an OrderedDict # Convert the spaces into an OrderedDict
if isinstance(spaces, collections.abc.Mapping) and not isinstance( if isinstance(spaces, collections.abc.Mapping):
spaces, OrderedDict # for legacy reasons, we need to preserve the sorted dictionary items.
): # as this could matter for projects flatten the dictionary.
try: try:
spaces = OrderedDict(sorted(spaces.items())) spaces = dict(sorted(spaces.items()))
except TypeError: except TypeError:
# Incomparable types (e.g. `int` vs. `str`, or user-defined types) found. # Incomparable types (e.g. `int` vs. `str`, or user-defined types) found.
# The keys remain in the insertion order. # The keys remain in the insertion order.
spaces = OrderedDict(spaces.items()) spaces = dict(spaces.items())
elif isinstance(spaces, Sequence): elif isinstance(spaces, Sequence):
spaces = OrderedDict(spaces) spaces = dict(spaces)
elif spaces is None: elif spaces is None:
spaces = OrderedDict() spaces = dict()
else: else:
assert isinstance( raise TypeError(
spaces, OrderedDict f"Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: {type(spaces)}"
), f"Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: {type(spaces)}" )
# Add kwargs to spaces to allow both dictionary and keywords to be used # Add kwargs to spaces to allow both dictionary and keywords to be used
for key, space in spaces_kwargs.items(): for key, space in spaces_kwargs.items():
@@ -164,11 +163,9 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
assert ( assert (
mask.keys() == self.spaces.keys() mask.keys() == self.spaces.keys()
), f"Expect mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}" ), f"Expect mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"
return OrderedDict( return {k: space.sample(mask=mask[k]) for k, space in self.spaces.items()}
[(k, space.sample(mask[k])) for k, space in self.spaces.items()]
)
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()]) return {k: space.sample() for k, space in self.spaces.items()}
def contains(self, x: Any) -> bool: def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space."""
@@ -221,9 +218,7 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
for key, space in self.spaces.items() for key, space in self.spaces.items()
} }
def from_jsonable( def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]:
self, sample_n: dict[str, list[Any]]
) -> list[OrderedDict[str, Any]]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
dict_of_list: dict[str, list[Any]] = { dict_of_list: dict[str, list[Any]] = {
key: space.from_jsonable(sample_n[key]) key: space.from_jsonable(sample_n[key])
@@ -232,7 +227,7 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
n_elements = len(next(iter(dict_of_list.values()))) n_elements = len(next(iter(dict_of_list.values())))
result = [ result = [
OrderedDict({key: value[n] for key, value in dict_of_list.items()}) {key: value[n] for key, value in dict_of_list.items()}
for n in range(n_elements) for n in range(n_elements)
] ]
return result return result

View File

@@ -7,7 +7,6 @@ from __future__ import annotations
import operator as op import operator as op
import typing import typing
from collections import OrderedDict
from functools import reduce, singledispatch from functools import reduce, singledispatch
from typing import Any, TypeVar, Union, cast from typing import Any, TypeVar, Union, cast
@@ -201,7 +200,7 @@ def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[An
return np.concatenate( return np.concatenate(
[np.array(flatten(s, x[key])) for key, s in space.spaces.items()] [np.array(flatten(s, x[key])) for key, s in space.spaces.items()]
) )
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items()) return {key: flatten(s, x[key]) for key, s in space.spaces.items()}
@flatten.register(Graph) @flatten.register(Graph)
@@ -361,16 +360,15 @@ def _unflatten_dict(space: Dict, x: NDArray[Any] | dict[str, Any]) -> dict[str,
if space.is_np_flattenable: if space.is_np_flattenable:
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_) dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1])) list_flattened = np.split(x, np.cumsum(dims[:-1]))
return OrderedDict( return {
[ key: unflatten(s, flattened)
(key, unflatten(s, flattened)) for flattened, (key, s) in zip(list_flattened, space.spaces.items())
for flattened, (key, s) in zip(list_flattened, space.spaces.items()) }
]
)
assert isinstance( assert isinstance(
x, dict x, dict
), f"{space} is not numpy-flattenable. Thus, you should only unflatten dictionary for this space. Got a {type(x)}" ), f"{space} is not numpy-flattenable. Thus, you should only unflatten dictionary for this space. Got a {type(x)}"
return OrderedDict((key, unflatten(s, x[key])) for key, s in space.spaces.items()) return {key: unflatten(s, x[key]) for key, s in space.spaces.items()}
@unflatten.register(Graph) @unflatten.register(Graph)
@@ -532,9 +530,7 @@ def _flatten_space_dict(space: Dict) -> Box | Dict:
dtype=np.result_type(*[s.dtype for s in space_list]), dtype=np.result_type(*[s.dtype for s in space_list]),
) )
return Dict( return Dict(
spaces=OrderedDict( spaces={key: flatten_space(space) for key, space in space.spaces.items()}
(key, flatten_space(space)) for key, space in space.spaces.items()
)
) )

View File

@@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import multiprocessing as mp import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool from ctypes import c_bool
from functools import singledispatch from functools import singledispatch
from typing import Any from typing import Any
@@ -81,12 +80,10 @@ def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
@create_shared_memory.register(Dict) @create_shared_memory.register(Dict)
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp): def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
return OrderedDict( return {
[ key: create_shared_memory(subspace, n=n, ctx=ctx)
(key, create_shared_memory(subspace, n=n, ctx=ctx)) for (key, subspace) in space.spaces.items()
for (key, subspace) in space.spaces.items() }
]
)
@create_shared_memory.register(Text) @create_shared_memory.register(Text)
@@ -163,15 +160,12 @@ def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
@read_from_shared_memory.register(Dict) @read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1): def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
subspace_samples = OrderedDict( subspace_samples = {
[ key: read_from_shared_memory(subspace, shared_memory[key], n=n)
(key, read_from_shared_memory(subspace, shared_memory[key], n=n)) for (key, subspace) in space.spaces.items()
for (key, subspace) in space.spaces.items() }
]
)
return tuple( return tuple(
OrderedDict({key: subspace_samples[key][i] for key in space.keys()}) {key: subspace_samples[key][i] for key in space.keys()} for i in range(n)
for i in range(n)
) )

View File

@@ -7,7 +7,6 @@
""" """
from __future__ import annotations from __future__ import annotations
from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from functools import singledispatch from functools import singledispatch
from typing import Any, Iterable, Iterator from typing import Any, Iterable, Iterator
@@ -163,9 +162,9 @@ def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
>>> items = space.sample() >>> items = space.sample()
>>> it = iterate(space, items) >>> it = iterate(space, items)
>>> next(it) >>> next(it)
OrderedDict([('position', array([0.77395606, 0.43887845, 0.85859793], dtype=float32)), ('velocity', array([0.77395606, 0.43887845], dtype=float32))]) {'position': array([0.77395606, 0.43887845, 0.85859793], dtype=float32), 'velocity': array([0.77395606, 0.43887845], dtype=float32)}
>>> next(it) >>> next(it)
OrderedDict([('position', array([0.697368 , 0.09417735, 0.97562236], dtype=float32)), ('velocity', array([0.85859793, 0.697368 ], dtype=float32))]) {'position': array([0.697368 , 0.09417735, 0.97562236], dtype=float32), 'velocity': array([0.85859793, 0.697368 ], dtype=float32)}
>>> next(it) >>> next(it)
Traceback (most recent call last): Traceback (most recent call last):
... ...
@@ -226,7 +225,7 @@ def _iterate_dict(space: Dict, items: dict[str, Any]):
] ]
) )
for item in zip(*values): for item in zip(*values):
yield OrderedDict({key: value for key, value in zip(keys, item)}) yield {key: value for key, value in zip(keys, item)}
@singledispatch @singledispatch
@@ -287,12 +286,10 @@ def _concatenate_tuple(
def _concatenate_dict( def _concatenate_dict(
space: Dict, items: Iterable, out: dict[str, Any] space: Dict, items: Iterable, out: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
return OrderedDict( return {
{ key: concatenate(subspace, [item[key] for item in items], out[key])
key: concatenate(subspace, [item[key] for item in items], out[key]) for key, subspace in space.items()
for key, subspace in space.items() }
}
)
@concatenate.register(Graph) @concatenate.register(Graph)
@@ -330,9 +327,9 @@ def create_empty_array(
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32), ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)}) ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
>>> create_empty_array(space, n=2, fn=np.zeros) >>> create_empty_array(space, n=2, fn=np.zeros)
OrderedDict([('position', array([[0., 0., 0.], {'position': array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.], [0., 0., 0.]], dtype=float32), 'velocity': array([[0., 0.],
[0., 0.]], dtype=float32))]) [0., 0.]], dtype=float32)}
""" """
raise TypeError( raise TypeError(
f"The space provided to `create_empty_array` is not a gymnasium Space instance, type: {type(space)}, {space}" f"The space provided to `create_empty_array` is not a gymnasium Space instance, type: {type(space)}, {space}"
@@ -356,12 +353,9 @@ def _create_empty_array_tuple(space: Tuple, n: int = 1, fn=np.zeros) -> tuple[An
@create_empty_array.register(Dict) @create_empty_array.register(Dict)
def _create_empty_array_dict(space: Dict, n: int = 1, fn=np.zeros) -> dict[str, Any]: def _create_empty_array_dict(space: Dict, n: int = 1, fn=np.zeros) -> dict[str, Any]:
return OrderedDict( return {
{ key: create_empty_array(subspace, n=n, fn=fn) for key, subspace in space.items()
key: create_empty_array(subspace, n=n, fn=fn) }
for key, subspace in space.items()
}
)
@create_empty_array.register(Graph) @create_empty_array.register(Graph)

View File

@@ -1,5 +1,4 @@
"""Utility functions for the wrappers.""" """Utility functions for the wrappers."""
from collections import OrderedDict
from functools import singledispatch from functools import singledispatch
import numpy as np import numpy as np
@@ -119,9 +118,7 @@ def _create_tuple_zero_array(space: Tuple):
@create_zero_array.register(Dict) @create_zero_array.register(Dict)
def _create_dict_zero_array(space: Dict): def _create_dict_zero_array(space: Dict):
return OrderedDict( return {key: create_zero_array(subspace) for key, subspace in space.spaces.items()}
{key: create_zero_array(subspace) for key, subspace in space.spaces.items()}
)
@create_zero_array.register(Sequence) @create_zero_array.register(Sequence)

View File

@@ -189,10 +189,10 @@ class FilterObservation(VectorizeTransformObservation):
>>> obs, info = envs.reset(seed=123) >>> obs, info = envs.reset(seed=123)
>>> envs.close() >>> envs.close()
>>> obs >>> obs
OrderedDict([('obs', array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], {'obs': array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
[ 0.02852531, 0.02858594, 0.0469136 , 0.02480598], [ 0.02852531, 0.02858594, 0.0469136 , 0.02480598],
[ 0.03517495, -0.000635 , -0.01098382, -0.03203924]], [ 0.03517495, -0.000635 , -0.01098382, -0.03203924]],
dtype=float32))]) dtype=float32)}
""" """
def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]): def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]):

View File

@@ -10,7 +10,7 @@ from gymnasium.spaces import Box, Dict, Discrete
def test_dict_init(): def test_dict_init():
with pytest.raises( with pytest.raises(
AssertionError, TypeError,
match=r"^Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: ", match=r"^Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: ",
): ):
Dict(Discrete(2)) Dict(Discrete(2))