diff --git a/gymnasium/spaces/oneof.py b/gymnasium/spaces/oneof.py index d88f0b130..be158e614 100644 --- a/gymnasium/spaces/oneof.py +++ b/gymnasium/spaces/oneof.py @@ -43,6 +43,7 @@ class OneOf(Space[Any]): spaces (Iterable[Space]): The spaces that are involved in the cartesian product. seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling. """ + assert isinstance(spaces, Iterable), f"{spaces} is not an iterable" self.spaces = tuple(spaces) assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported." for space in self.spaces: @@ -105,7 +106,7 @@ class OneOf(Space[Any]): Returns: Tuple of the subspace's samples """ - subspace_idx = int(self.np_random.integers(0, len(self.spaces))) + subspace_idx = self.np_random.integers(0, len(self.spaces), dtype=np.int64) subspace = self.spaces[subspace_idx] if mask is not None: assert isinstance( @@ -121,9 +122,14 @@ class OneOf(Space[Any]): def contains(self, x: tuple[int, Any]) -> bool: """Return boolean specifying if x is a valid member of this space.""" - (idx, value) = x - - return isinstance(x, tuple) and self.spaces[idx].contains(value) + # subspace_idx, subspace_value = x + return ( + isinstance(x, tuple) + and len(x) == 2 + and isinstance(x[0], (np.int64, int)) + and 0 <= x[0] < len(self.spaces) + and self.spaces[x[0]].contains(x[1]) + ) def __repr__(self) -> str: """Gives a string representation of this space.""" @@ -141,7 +147,10 @@ class OneOf(Space[Any]): def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]: """Convert a JSONable data type to a batch of samples from this space.""" return [ - (space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0]) + ( + np.int64(space_idx), + self.spaces[space_idx].from_jsonable([jsonable_sample])[0], + ) for space_idx, jsonable_sample in sample_n ] diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py index 8f737af28..ed53ca566 100644 --- a/gymnasium/spaces/utils.py +++ b/gymnasium/spaces/utils.py @@ -421,7 +421,7 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] @unflatten.register(OneOf) def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]: - idx = int(x[0]) + idx = np.int64(x[0]) sub_space = space.spaces[idx] original_size = flatdim(sub_space) diff --git a/gymnasium/vector/utils/shared_memory.py b/gymnasium/vector/utils/shared_memory.py index e30444862..66494718c 100644 --- a/gymnasium/vector/utils/shared_memory.py +++ b/gymnasium/vector/utils/shared_memory.py @@ -93,14 +93,16 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp): @create_shared_memory.register(OneOf) def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp): - return (ctx.Array(np.int32, n),) + _create_tuple_shared_memory(space) + return (ctx.Array(np.dtype(np.int64).char, n),) + tuple( + create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces + ) @create_shared_memory.register(Graph) @create_shared_memory.register(Sequence) def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp): raise TypeError( - f"As {space} has a dynamic shape then it is not possible to make a static shared memory." + f"As {space} has a dynamic shape so its not possible to make a static shared memory." ) @@ -193,14 +195,15 @@ def _read_text_from_shared_memory( def _read_one_of_from_shared_memory( space: OneOf, shared_memory, n: int = 1 ) -> tuple[Any, ...]: - sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=space.dtype) + sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64) + subspace_samples = tuple( read_from_shared_memory(subspace, memory, n=n) for (memory, subspace) in zip(shared_memory[1:], space.spaces) ) return tuple( - (index, sample[index]) - for index, sample in zip(sample_indexes, subspace_samples) + (sample_index, subspace_samples[sample_index][index]) + for index, sample_index in enumerate(sample_indexes) ) @@ -279,10 +282,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me @write_to_shared_memory.register(OneOf) def _write_oneof_to_shared_memory( - space: OneOf, index: int, values: tuple[Any, ...], shared_memory + space: OneOf, index: int, values: tuple[int, Any], shared_memory ): - destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32) - np.copyto(destination[index : index + 1], values[0]) + subspace_idx, space_value = values - for value, memory, subspace in zip(values[1], shared_memory[1:], space.spaces): - write_to_shared_memory(subspace, index, value, memory) + destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64) + np.copyto(destination[index : index + 1], subspace_idx) + + # only the subspace's memory is updated with the sample value, ignoring the other memories as data might not match + write_to_shared_memory( + space.spaces[subspace_idx], index, space_value, shared_memory[1 + subspace_idx] + ) diff --git a/tests/envs/mujoco/test_mujoco_v5.py b/tests/envs/mujoco/test_mujoco_v5.py index 44ba9a6b2..02149b360 100644 --- a/tests/envs/mujoco/test_mujoco_v5.py +++ b/tests/envs/mujoco/test_mujoco_v5.py @@ -436,7 +436,7 @@ def test_observation_structure(env_name: str, version: str): env = gym.make(f"{env_name}-{version}").unwrapped assert isinstance(env, MujocoEnv) if not hasattr(env, "observation_structure"): - return + pytest.skip("Environment doesn't have an `observation_structure` attribute") obs_struct = env.observation_structure diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index afe230e09..d42ad0eb7 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -71,7 +71,8 @@ def test_discrete_actions_out_of_bound(env: gym.Env): env (gym.Env): the gymnasium environment """ if env.metadata.get("jax", False): - return + assert env.spec is not None + pytest.skip(f"Skipping jax-based environment ({env.spec.id})") assert isinstance(env.action_space, spaces.Discrete) upper_bound = env.action_space.start + env.action_space.n - 1 @@ -106,7 +107,8 @@ def test_box_actions_out_of_bound(env: gym.Env): env (gym.Env): the gymnasium environment """ if env.metadata.get("jax", False): - return + assert env.spec is not None + pytest.skip(f"Skipping jax-based environment ({env.spec.id})") env.reset(seed=42) diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index f6b48b135..0855871a0 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -63,7 +63,7 @@ def test_all_env_passive_env_checker(spec): for warning in caught_warnings: if not passive_check_pattern.search(str(warning.message)): - print(f"Unexpected warning: {warning.message}") + raise ValueError(f"Unexpected warning: {warning.message}") # Note that this precludes running this test in multiple threads. @@ -90,7 +90,7 @@ def test_env_determinism_rollout(env_spec: EnvSpec): """ # Don't check rollout equality if it's a nondeterministic environment. if env_spec.nondeterministic is True: - return + pytest.skip(f"Skipping {env_spec.id} as it is non-deterministic") env_1 = env_spec.make(disable_env_checker=True) env_2 = env_spec.make(disable_env_checker=True) diff --git a/tests/spaces/test_oneof.py b/tests/spaces/test_oneof.py index 61088768f..f4a879bf9 100644 --- a/tests/spaces/test_oneof.py +++ b/tests/spaces/test_oneof.py @@ -62,6 +62,10 @@ def test_oneof_contains(): assert (0, np.array([0.5], dtype=np.float32)) in space assert (1, np.array([-0.5, -0.5], dtype=np.float32)) in space + assert (np.int64(0), np.array([0.5], dtype=np.float32)) in space + + assert (np.int32(0), np.array([0.5], dtype=np.float32)) not in space + def test_bad_oneof_seed(): space = OneOf([Box(0, 1), Box(0, 1)]) diff --git a/tests/test_core.py b/tests/test_core.py index a91d3d6a7..374bf4bb9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -118,7 +118,7 @@ def test_example_wrapper(example_env): "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." ), ): - print(wrapper_env.access_hidden_np_random()) + _ = wrapper_env.access_hidden_np_random() class ExampleRewardWrapper(RewardWrapper): diff --git a/tests/vector/utils/test_shared_memory.py b/tests/vector/utils/test_shared_memory.py index 9b3118c64..3edf6a45d 100644 --- a/tests/vector/utils/test_shared_memory.py +++ b/tests/vector/utils/test_shared_memory.py @@ -33,8 +33,12 @@ def test_shared_memory_create_read_write(space, num, ctx): try: shared_memory = create_shared_memory(space, n=num, ctx=ctx) - except TypeError: - return + except TypeError as err: + assert ( + "has a dynamic shape so its not possible to make a static shared memory." + in str(err) + ) + pytest.skip("Skipping space with dynamic shape") for i, sample in enumerate(samples): write_to_shared_memory(space, i, sample, shared_memory) diff --git a/tests/wrappers/test_jax_to_torch.py b/tests/wrappers/test_jax_to_torch.py index ef2af4900..54e52e5be 100644 --- a/tests/wrappers/test_jax_to_torch.py +++ b/tests/wrappers/test_jax_to_torch.py @@ -96,9 +96,6 @@ class ExampleNamedTuple(NamedTuple): ) def test_roundtripping(value, expected_value): """We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper.""" - print(f"{value=}") - print(f"{torch_to_jax(value)=}") - print(f"{jax_to_torch(torch_to_jax(value))=}") roundtripped_value = jax_to_torch(torch_to_jax(value)) assert torch_data_equivalence(roundtripped_value, expected_value)