Fix OneOf shared memory and add pytest.skip to tests (#999)

This commit is contained in:
Mark Towers
2024-04-06 13:20:10 +01:00
committed by GitHub
parent 9c812af180
commit 0fe94efa26
10 changed files with 50 additions and 27 deletions

View File

@@ -43,6 +43,7 @@ class OneOf(Space[Any]):
spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
"""
assert isinstance(spaces, Iterable), f"{spaces} is not an iterable"
self.spaces = tuple(spaces)
assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
for space in self.spaces:
@@ -105,7 +106,7 @@ class OneOf(Space[Any]):
Returns:
Tuple of the subspace's samples
"""
subspace_idx = int(self.np_random.integers(0, len(self.spaces)))
subspace_idx = self.np_random.integers(0, len(self.spaces), dtype=np.int64)
subspace = self.spaces[subspace_idx]
if mask is not None:
assert isinstance(
@@ -121,9 +122,14 @@ class OneOf(Space[Any]):
def contains(self, x: tuple[int, Any]) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
(idx, value) = x
return isinstance(x, tuple) and self.spaces[idx].contains(value)
# subspace_idx, subspace_value = x
return (
isinstance(x, tuple)
and len(x) == 2
and isinstance(x[0], (np.int64, int))
and 0 <= x[0] < len(self.spaces)
and self.spaces[x[0]].contains(x[1])
)
def __repr__(self) -> str:
"""Gives a string representation of this space."""
@@ -141,7 +147,10 @@ class OneOf(Space[Any]):
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [
(space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0])
(
np.int64(space_idx),
self.spaces[space_idx].from_jsonable([jsonable_sample])[0],
)
for space_idx, jsonable_sample in sample_n
]

View File

@@ -421,7 +421,7 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]
@unflatten.register(OneOf)
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
idx = int(x[0])
idx = np.int64(x[0])
sub_space = space.spaces[idx]
original_size = flatdim(sub_space)

View File

@@ -93,14 +93,16 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
@create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
return (ctx.Array(np.int32, n),) + _create_tuple_shared_memory(space)
return (ctx.Array(np.dtype(np.int64).char, n),) + tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
raise TypeError(
f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
f"As {space} has a dynamic shape so its not possible to make a static shared memory."
)
@@ -193,14 +195,15 @@ def _read_text_from_shared_memory(
def _read_one_of_from_shared_memory(
space: OneOf, shared_memory, n: int = 1
) -> tuple[Any, ...]:
sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=space.dtype)
sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)
subspace_samples = tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory[1:], space.spaces)
)
return tuple(
(index, sample[index])
for index, sample in zip(sample_indexes, subspace_samples)
(sample_index, subspace_samples[sample_index][index])
for index, sample_index in enumerate(sample_indexes)
)
@@ -279,10 +282,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me
@write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory(
space: OneOf, index: int, values: tuple[Any, ...], shared_memory
space: OneOf, index: int, values: tuple[int, Any], shared_memory
):
destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
np.copyto(destination[index : index + 1], values[0])
subspace_idx, space_value = values
for value, memory, subspace in zip(values[1], shared_memory[1:], space.spaces):
write_to_shared_memory(subspace, index, value, memory)
destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)
np.copyto(destination[index : index + 1], subspace_idx)
# only the subspace's memory is updated with the sample value, ignoring the other memories as data might not match
write_to_shared_memory(
space.spaces[subspace_idx], index, space_value, shared_memory[1 + subspace_idx]
)

View File

@@ -436,7 +436,7 @@ def test_observation_structure(env_name: str, version: str):
env = gym.make(f"{env_name}-{version}").unwrapped
assert isinstance(env, MujocoEnv)
if not hasattr(env, "observation_structure"):
return
pytest.skip("Environment doesn't have an `observation_structure` attribute")
obs_struct = env.observation_structure

View File

@@ -71,7 +71,8 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
return
assert env.spec is not None
pytest.skip(f"Skipping jax-based environment ({env.spec.id})")
assert isinstance(env.action_space, spaces.Discrete)
upper_bound = env.action_space.start + env.action_space.n - 1
@@ -106,7 +107,8 @@ def test_box_actions_out_of_bound(env: gym.Env):
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
return
assert env.spec is not None
pytest.skip(f"Skipping jax-based environment ({env.spec.id})")
env.reset(seed=42)

View File

@@ -63,7 +63,7 @@ def test_all_env_passive_env_checker(spec):
for warning in caught_warnings:
if not passive_check_pattern.search(str(warning.message)):
print(f"Unexpected warning: {warning.message}")
raise ValueError(f"Unexpected warning: {warning.message}")
# Note that this precludes running this test in multiple threads.
@@ -90,7 +90,7 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
"""
# Don't check rollout equality if it's a nondeterministic environment.
if env_spec.nondeterministic is True:
return
pytest.skip(f"Skipping {env_spec.id} as it is non-deterministic")
env_1 = env_spec.make(disable_env_checker=True)
env_2 = env_spec.make(disable_env_checker=True)

View File

@@ -62,6 +62,10 @@ def test_oneof_contains():
assert (0, np.array([0.5], dtype=np.float32)) in space
assert (1, np.array([-0.5, -0.5], dtype=np.float32)) in space
assert (np.int64(0), np.array([0.5], dtype=np.float32)) in space
assert (np.int32(0), np.array([0.5], dtype=np.float32)) not in space
def test_bad_oneof_seed():
space = OneOf([Box(0, 1), Box(0, 1)])

View File

@@ -118,7 +118,7 @@ def test_example_wrapper(example_env):
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
),
):
print(wrapper_env.access_hidden_np_random())
_ = wrapper_env.access_hidden_np_random()
class ExampleRewardWrapper(RewardWrapper):

View File

@@ -33,8 +33,12 @@ def test_shared_memory_create_read_write(space, num, ctx):
try:
shared_memory = create_shared_memory(space, n=num, ctx=ctx)
except TypeError:
return
except TypeError as err:
assert (
"has a dynamic shape so its not possible to make a static shared memory."
in str(err)
)
pytest.skip("Skipping space with dynamic shape")
for i, sample in enumerate(samples):
write_to_shared_memory(space, i, sample, shared_memory)

View File

@@ -96,9 +96,6 @@ class ExampleNamedTuple(NamedTuple):
)
def test_roundtripping(value, expected_value):
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
print(f"{value=}")
print(f"{torch_to_jax(value)=}")
print(f"{jax_to_torch(torch_to_jax(value))=}")
roundtripped_value = jax_to_torch(torch_to_jax(value))
assert torch_data_equivalence(roundtripped_value, expected_value)