Fix OneOf shared memory and add pytest.skip to tests (#999)

This commit is contained in:
Mark Towers
2024-04-06 13:20:10 +01:00
committed by GitHub
parent 9c812af180
commit 0fe94efa26
10 changed files with 50 additions and 27 deletions

View File

@@ -43,6 +43,7 @@ class OneOf(Space[Any]):
spaces (Iterable[Space]): The spaces that are involved in the cartesian product. spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling. seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
""" """
assert isinstance(spaces, Iterable), f"{spaces} is not an iterable"
self.spaces = tuple(spaces) self.spaces = tuple(spaces)
assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported." assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
for space in self.spaces: for space in self.spaces:
@@ -105,7 +106,7 @@ class OneOf(Space[Any]):
Returns: Returns:
Tuple of the subspace's samples Tuple of the subspace's samples
""" """
subspace_idx = int(self.np_random.integers(0, len(self.spaces))) subspace_idx = self.np_random.integers(0, len(self.spaces), dtype=np.int64)
subspace = self.spaces[subspace_idx] subspace = self.spaces[subspace_idx]
if mask is not None: if mask is not None:
assert isinstance( assert isinstance(
@@ -121,9 +122,14 @@ class OneOf(Space[Any]):
def contains(self, x: tuple[int, Any]) -> bool: def contains(self, x: tuple[int, Any]) -> bool:
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space."""
(idx, value) = x # subspace_idx, subspace_value = x
return (
return isinstance(x, tuple) and self.spaces[idx].contains(value) isinstance(x, tuple)
and len(x) == 2
and isinstance(x[0], (np.int64, int))
and 0 <= x[0] < len(self.spaces)
and self.spaces[x[0]].contains(x[1])
)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Gives a string representation of this space.""" """Gives a string representation of this space."""
@@ -141,7 +147,10 @@ class OneOf(Space[Any]):
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]: def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
return [ return [
(space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0]) (
np.int64(space_idx),
self.spaces[space_idx].from_jsonable([jsonable_sample])[0],
)
for space_idx, jsonable_sample in sample_n for space_idx, jsonable_sample in sample_n
] ]

View File

@@ -421,7 +421,7 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]
@unflatten.register(OneOf) @unflatten.register(OneOf)
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]: def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
idx = int(x[0]) idx = np.int64(x[0])
sub_space = space.spaces[idx] sub_space = space.spaces[idx]
original_size = flatdim(sub_space) original_size = flatdim(sub_space)

View File

@@ -93,14 +93,16 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
@create_shared_memory.register(OneOf) @create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp): def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
return (ctx.Array(np.int32, n),) + _create_tuple_shared_memory(space) return (ctx.Array(np.dtype(np.int64).char, n),) + tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
@create_shared_memory.register(Graph) @create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence) @create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp): def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
raise TypeError( raise TypeError(
f"As {space} has a dynamic shape then it is not possible to make a static shared memory." f"As {space} has a dynamic shape so its not possible to make a static shared memory."
) )
@@ -193,14 +195,15 @@ def _read_text_from_shared_memory(
def _read_one_of_from_shared_memory( def _read_one_of_from_shared_memory(
space: OneOf, shared_memory, n: int = 1 space: OneOf, shared_memory, n: int = 1
) -> tuple[Any, ...]: ) -> tuple[Any, ...]:
sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=space.dtype) sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)
subspace_samples = tuple( subspace_samples = tuple(
read_from_shared_memory(subspace, memory, n=n) read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory[1:], space.spaces) for (memory, subspace) in zip(shared_memory[1:], space.spaces)
) )
return tuple( return tuple(
(index, sample[index]) (sample_index, subspace_samples[sample_index][index])
for index, sample in zip(sample_indexes, subspace_samples) for index, sample_index in enumerate(sample_indexes)
) )
@@ -279,10 +282,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me
@write_to_shared_memory.register(OneOf) @write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory( def _write_oneof_to_shared_memory(
space: OneOf, index: int, values: tuple[Any, ...], shared_memory space: OneOf, index: int, values: tuple[int, Any], shared_memory
): ):
destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32) subspace_idx, space_value = values
np.copyto(destination[index : index + 1], values[0])
for value, memory, subspace in zip(values[1], shared_memory[1:], space.spaces): destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)
write_to_shared_memory(subspace, index, value, memory) np.copyto(destination[index : index + 1], subspace_idx)
# only the subspace's memory is updated with the sample value, ignoring the other memories as data might not match
write_to_shared_memory(
space.spaces[subspace_idx], index, space_value, shared_memory[1 + subspace_idx]
)

View File

@@ -436,7 +436,7 @@ def test_observation_structure(env_name: str, version: str):
env = gym.make(f"{env_name}-{version}").unwrapped env = gym.make(f"{env_name}-{version}").unwrapped
assert isinstance(env, MujocoEnv) assert isinstance(env, MujocoEnv)
if not hasattr(env, "observation_structure"): if not hasattr(env, "observation_structure"):
return pytest.skip("Environment doesn't have an `observation_structure` attribute")
obs_struct = env.observation_structure obs_struct = env.observation_structure

View File

@@ -71,7 +71,8 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
env (gym.Env): the gymnasium environment env (gym.Env): the gymnasium environment
""" """
if env.metadata.get("jax", False): if env.metadata.get("jax", False):
return assert env.spec is not None
pytest.skip(f"Skipping jax-based environment ({env.spec.id})")
assert isinstance(env.action_space, spaces.Discrete) assert isinstance(env.action_space, spaces.Discrete)
upper_bound = env.action_space.start + env.action_space.n - 1 upper_bound = env.action_space.start + env.action_space.n - 1
@@ -106,7 +107,8 @@ def test_box_actions_out_of_bound(env: gym.Env):
env (gym.Env): the gymnasium environment env (gym.Env): the gymnasium environment
""" """
if env.metadata.get("jax", False): if env.metadata.get("jax", False):
return assert env.spec is not None
pytest.skip(f"Skipping jax-based environment ({env.spec.id})")
env.reset(seed=42) env.reset(seed=42)

View File

@@ -63,7 +63,7 @@ def test_all_env_passive_env_checker(spec):
for warning in caught_warnings: for warning in caught_warnings:
if not passive_check_pattern.search(str(warning.message)): if not passive_check_pattern.search(str(warning.message)):
print(f"Unexpected warning: {warning.message}") raise ValueError(f"Unexpected warning: {warning.message}")
# Note that this precludes running this test in multiple threads. # Note that this precludes running this test in multiple threads.
@@ -90,7 +90,7 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
""" """
# Don't check rollout equality if it's a nondeterministic environment. # Don't check rollout equality if it's a nondeterministic environment.
if env_spec.nondeterministic is True: if env_spec.nondeterministic is True:
return pytest.skip(f"Skipping {env_spec.id} as it is non-deterministic")
env_1 = env_spec.make(disable_env_checker=True) env_1 = env_spec.make(disable_env_checker=True)
env_2 = env_spec.make(disable_env_checker=True) env_2 = env_spec.make(disable_env_checker=True)

View File

@@ -62,6 +62,10 @@ def test_oneof_contains():
assert (0, np.array([0.5], dtype=np.float32)) in space assert (0, np.array([0.5], dtype=np.float32)) in space
assert (1, np.array([-0.5, -0.5], dtype=np.float32)) in space assert (1, np.array([-0.5, -0.5], dtype=np.float32)) in space
assert (np.int64(0), np.array([0.5], dtype=np.float32)) in space
assert (np.int32(0), np.array([0.5], dtype=np.float32)) not in space
def test_bad_oneof_seed(): def test_bad_oneof_seed():
space = OneOf([Box(0, 1), Box(0, 1)]) space = OneOf([Box(0, 1), Box(0, 1)])

View File

@@ -118,7 +118,7 @@ def test_example_wrapper(example_env):
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
), ),
): ):
print(wrapper_env.access_hidden_np_random()) _ = wrapper_env.access_hidden_np_random()
class ExampleRewardWrapper(RewardWrapper): class ExampleRewardWrapper(RewardWrapper):

View File

@@ -33,8 +33,12 @@ def test_shared_memory_create_read_write(space, num, ctx):
try: try:
shared_memory = create_shared_memory(space, n=num, ctx=ctx) shared_memory = create_shared_memory(space, n=num, ctx=ctx)
except TypeError: except TypeError as err:
return assert (
"has a dynamic shape so its not possible to make a static shared memory."
in str(err)
)
pytest.skip("Skipping space with dynamic shape")
for i, sample in enumerate(samples): for i, sample in enumerate(samples):
write_to_shared_memory(space, i, sample, shared_memory) write_to_shared_memory(space, i, sample, shared_memory)

View File

@@ -96,9 +96,6 @@ class ExampleNamedTuple(NamedTuple):
) )
def test_roundtripping(value, expected_value): def test_roundtripping(value, expected_value):
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper.""" """We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
print(f"{value=}")
print(f"{torch_to_jax(value)=}")
print(f"{jax_to_torch(torch_to_jax(value))=}")
roundtripped_value = jax_to_torch(torch_to_jax(value)) roundtripped_value = jax_to_torch(torch_to_jax(value))
assert torch_data_equivalence(roundtripped_value, expected_value) assert torch_data_equivalence(roundtripped_value, expected_value)