mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-26 00:07:41 +00:00
Fix OneOf
shared memory and add pytest.skip
to tests (#999)
This commit is contained in:
@@ -43,6 +43,7 @@ class OneOf(Space[Any]):
|
||||
spaces (Iterable[Space]): The candidate subspaces; each sample of this space comes from exactly one of them.
|
||||
seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
|
||||
"""
|
||||
assert isinstance(spaces, Iterable), f"{spaces} is not an iterable"
|
||||
self.spaces = tuple(spaces)
|
||||
assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
|
||||
for space in self.spaces:
|
||||
@@ -105,7 +106,7 @@ class OneOf(Space[Any]):
|
||||
Returns:
|
||||
Tuple of the selected subspace's index and a sample from that subspace
|
||||
"""
|
||||
subspace_idx = int(self.np_random.integers(0, len(self.spaces)))
|
||||
subspace_idx = self.np_random.integers(0, len(self.spaces), dtype=np.int64)
|
||||
subspace = self.spaces[subspace_idx]
|
||||
if mask is not None:
|
||||
assert isinstance(
|
||||
@@ -121,9 +122,14 @@ class OneOf(Space[Any]):
|
||||
|
||||
def contains(self, x: tuple[int, Any]) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
(idx, value) = x
|
||||
|
||||
return isinstance(x, tuple) and self.spaces[idx].contains(value)
|
||||
# subspace_idx, subspace_value = x
|
||||
return (
|
||||
isinstance(x, tuple)
|
||||
and len(x) == 2
|
||||
and isinstance(x[0], (np.int64, int))
|
||||
and 0 <= x[0] < len(self.spaces)
|
||||
and self.spaces[x[0]].contains(x[1])
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Gives a string representation of this space."""
|
||||
@@ -141,7 +147,10 @@ class OneOf(Space[Any]):
|
||||
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [
|
||||
(space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0])
|
||||
(
|
||||
np.int64(space_idx),
|
||||
self.spaces[space_idx].from_jsonable([jsonable_sample])[0],
|
||||
)
|
||||
for space_idx, jsonable_sample in sample_n
|
||||
]
|
||||
|
||||
|
@@ -421,7 +421,7 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]
|
||||
|
||||
@unflatten.register(OneOf)
|
||||
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
|
||||
idx = int(x[0])
|
||||
idx = np.int64(x[0])
|
||||
sub_space = space.spaces[idx]
|
||||
|
||||
original_size = flatdim(sub_space)
|
||||
|
@@ -93,14 +93,16 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
|
||||
|
||||
@create_shared_memory.register(OneOf)
|
||||
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
|
||||
return (ctx.Array(np.int32, n),) + _create_tuple_shared_memory(space)
|
||||
return (ctx.Array(np.dtype(np.int64).char, n),) + tuple(
|
||||
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
|
||||
)
|
||||
|
||||
|
||||
@create_shared_memory.register(Graph)
|
||||
@create_shared_memory.register(Sequence)
|
||||
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
|
||||
raise TypeError(
|
||||
f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
|
||||
f"As {space} has a dynamic shape so its not possible to make a static shared memory."
|
||||
)
|
||||
|
||||
|
||||
@@ -193,14 +195,15 @@ def _read_text_from_shared_memory(
|
||||
def _read_one_of_from_shared_memory(
|
||||
space: OneOf, shared_memory, n: int = 1
|
||||
) -> tuple[Any, ...]:
|
||||
sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=space.dtype)
|
||||
sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)
|
||||
|
||||
subspace_samples = tuple(
|
||||
read_from_shared_memory(subspace, memory, n=n)
|
||||
for (memory, subspace) in zip(shared_memory[1:], space.spaces)
|
||||
)
|
||||
return tuple(
|
||||
(index, sample[index])
|
||||
for index, sample in zip(sample_indexes, subspace_samples)
|
||||
(sample_index, subspace_samples[sample_index][index])
|
||||
for index, sample_index in enumerate(sample_indexes)
|
||||
)
|
||||
|
||||
|
||||
@@ -279,10 +282,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me
|
||||
|
||||
@write_to_shared_memory.register(OneOf)
|
||||
def _write_oneof_to_shared_memory(
|
||||
space: OneOf, index: int, values: tuple[Any, ...], shared_memory
|
||||
space: OneOf, index: int, values: tuple[int, Any], shared_memory
|
||||
):
|
||||
destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
|
||||
np.copyto(destination[index : index + 1], values[0])
|
||||
subspace_idx, space_value = values
|
||||
|
||||
for value, memory, subspace in zip(values[1], shared_memory[1:], space.spaces):
|
||||
write_to_shared_memory(subspace, index, value, memory)
|
||||
destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)
|
||||
np.copyto(destination[index : index + 1], subspace_idx)
|
||||
|
||||
# only the subspace's memory is updated with the sample value, ignoring the other memories as data might not match
|
||||
write_to_shared_memory(
|
||||
space.spaces[subspace_idx], index, space_value, shared_memory[1 + subspace_idx]
|
||||
)
|
||||
|
@@ -436,7 +436,7 @@ def test_observation_structure(env_name: str, version: str):
|
||||
env = gym.make(f"{env_name}-{version}").unwrapped
|
||||
assert isinstance(env, MujocoEnv)
|
||||
if not hasattr(env, "observation_structure"):
|
||||
return
|
||||
pytest.skip("Environment doesn't have an `observation_structure` attribute")
|
||||
|
||||
obs_struct = env.observation_structure
|
||||
|
||||
|
@@ -71,7 +71,8 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
|
||||
env (gym.Env): the gymnasium environment
|
||||
"""
|
||||
if env.metadata.get("jax", False):
|
||||
return
|
||||
assert env.spec is not None
|
||||
pytest.skip(f"Skipping jax-based environment ({env.spec.id})")
|
||||
|
||||
assert isinstance(env.action_space, spaces.Discrete)
|
||||
upper_bound = env.action_space.start + env.action_space.n - 1
|
||||
@@ -106,7 +107,8 @@ def test_box_actions_out_of_bound(env: gym.Env):
|
||||
env (gym.Env): the gymnasium environment
|
||||
"""
|
||||
if env.metadata.get("jax", False):
|
||||
return
|
||||
assert env.spec is not None
|
||||
pytest.skip(f"Skipping jax-based environment ({env.spec.id})")
|
||||
|
||||
env.reset(seed=42)
|
||||
|
||||
|
@@ -63,7 +63,7 @@ def test_all_env_passive_env_checker(spec):
|
||||
|
||||
for warning in caught_warnings:
|
||||
if not passive_check_pattern.search(str(warning.message)):
|
||||
print(f"Unexpected warning: {warning.message}")
|
||||
raise ValueError(f"Unexpected warning: {warning.message}")
|
||||
|
||||
|
||||
# Note that this precludes running this test in multiple threads.
|
||||
@@ -90,7 +90,7 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
||||
"""
|
||||
# Don't check rollout equality if it's a nondeterministic environment.
|
||||
if env_spec.nondeterministic is True:
|
||||
return
|
||||
pytest.skip(f"Skipping {env_spec.id} as it is non-deterministic")
|
||||
|
||||
env_1 = env_spec.make(disable_env_checker=True)
|
||||
env_2 = env_spec.make(disable_env_checker=True)
|
||||
|
@@ -62,6 +62,10 @@ def test_oneof_contains():
|
||||
assert (0, np.array([0.5], dtype=np.float32)) in space
|
||||
assert (1, np.array([-0.5, -0.5], dtype=np.float32)) in space
|
||||
|
||||
assert (np.int64(0), np.array([0.5], dtype=np.float32)) in space
|
||||
|
||||
assert (np.int32(0), np.array([0.5], dtype=np.float32)) not in space
|
||||
|
||||
|
||||
def test_bad_oneof_seed():
|
||||
space = OneOf([Box(0, 1), Box(0, 1)])
|
||||
|
@@ -118,7 +118,7 @@ def test_example_wrapper(example_env):
|
||||
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
|
||||
),
|
||||
):
|
||||
print(wrapper_env.access_hidden_np_random())
|
||||
_ = wrapper_env.access_hidden_np_random()
|
||||
|
||||
|
||||
class ExampleRewardWrapper(RewardWrapper):
|
||||
|
@@ -33,8 +33,12 @@ def test_shared_memory_create_read_write(space, num, ctx):
|
||||
|
||||
try:
|
||||
shared_memory = create_shared_memory(space, n=num, ctx=ctx)
|
||||
except TypeError:
|
||||
return
|
||||
except TypeError as err:
|
||||
assert (
|
||||
"has a dynamic shape so its not possible to make a static shared memory."
|
||||
in str(err)
|
||||
)
|
||||
pytest.skip("Skipping space with dynamic shape")
|
||||
|
||||
for i, sample in enumerate(samples):
|
||||
write_to_shared_memory(space, i, sample, shared_memory)
|
||||
|
@@ -96,9 +96,6 @@ class ExampleNamedTuple(NamedTuple):
|
||||
)
|
||||
def test_roundtripping(value, expected_value):
|
||||
"""We test torch -> jax -> torch as this is the direction used in the corresponding wrapper."""
|
||||
print(f"{value=}")
|
||||
print(f"{torch_to_jax(value)=}")
|
||||
print(f"{jax_to_torch(torch_to_jax(value))=}")
|
||||
roundtripped_value = jax_to_torch(torch_to_jax(value))
|
||||
assert torch_data_equivalence(roundtripped_value, expected_value)
|
||||
|
||||
|
Reference in New Issue
Block a user