2023-02-18 21:08:08 +00:00
|
|
|
"""Tests that the vector wrappers work as expected."""
|
2024-06-10 17:07:47 +01:00
|
|
|
|
2023-11-07 13:27:25 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from typing import Any
|
|
|
|
|
2023-02-12 07:49:37 -05:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
import gymnasium as gym
|
2023-11-07 13:27:25 +00:00
|
|
|
from gymnasium.core import ObsType
|
|
|
|
from gymnasium.vector import VectorWrapper
|
2023-02-12 07:49:37 -05:00
|
|
|
|
|
|
|
|
2023-02-18 21:08:08 +00:00
|
|
|
class DummyVectorWrapper(VectorWrapper):
    """Minimal vector wrapper that records how many times ``reset`` has been called on it."""

    def __init__(self, env):
        """Wrap ``env`` and start the reset counter at zero."""
        super().__init__(env)
        # Number of times ``reset`` has been invoked on this wrapper.
        self.counter = 0

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Bump ``counter``, then hand the reset off to ``VectorWrapper.reset``.

        The counter is incremented before delegating, so it also counts calls
        whose underlying reset raises.
        """
        self.counter = self.counter + 1
        reset_result = super().reset(seed=seed, options=options)
        return reset_result
|
|
|
|
|
2023-02-12 07:49:37 -05:00
|
|
|
|
|
|
|
def test_vector_env_wrapper_inheritance():
    """Check that a ``VectorWrapper`` subclass's overridden ``reset`` is the one that runs."""
    base_env = gym.make_vec("FrozenLake-v1", vectorization_mode="sync")
    counting_env = DummyVectorWrapper(base_env)

    # One reset through the wrapper should bump its counter exactly once.
    counting_env.reset()
    assert counting_env.counter == 1

    base_env.close()
|
|
|
|
|
2023-02-12 07:49:37 -05:00
|
|
|
|
|
|
|
def test_vector_env_wrapper_attributes():
    """Test that `call`, `get_attr` and `set_attr` behave the same on a wrapped and an unwrapped vector env.

    The attribute methods are reached through ``wrapped.env`` (the vector env held by the
    wrapper). Results are compared against a separately created, unwrapped vector env built
    from the same spec.
    """
    env = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
    wrapped = DummyVectorWrapper(
        gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
    )

    assert np.allclose(wrapped.env.call("gravity"), env.call("gravity"))

    new_gravity = [20.0, 20.0, 20.0]
    env.set_attr("gravity", new_gravity)
    wrapped.env.set_attr("gravity", new_gravity)

    assert np.allclose(wrapped.env.get_attr("gravity"), env.get_attr("gravity"))
    # Comparing the two envs with each other would also pass if `set_attr` were a
    # no-op, so additionally pin the value that was written.
    assert np.allclose(env.get_attr("gravity"), new_gravity)

    # Both vector envs own sub-environments; close them both, not just `env`.
    env.close()
    wrapped.close()
|
2024-04-11 11:58:53 +02:00
|
|
|
|
|
|
|
|
|
|
|
def test_vector_env_metadata():
    """Test that the `metadata` property of a VectorWrapper reads through to the vector env it wraps."""
    env = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
    wrapped = DummyVectorWrapper(
        gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
    )

    # Two identically built vector envs start out with equal metadata.
    assert env.metadata == wrapped.metadata

    # Changing an unrelated env must not leak into the wrapper.
    new_metadata = {"render_modes": ["rgb_array"]}
    env.metadata = new_metadata
    assert env.metadata != wrapped.metadata

    # Changing the env the wrapper actually holds must be visible through the wrapper.
    # The two assertions above would still pass if the wrapper only copied the
    # metadata once at construction, so this is what checks the forwarding.
    wrapped.env.metadata = new_metadata
    assert wrapped.metadata == new_metadata

    # Both vector envs own sub-environments; close them both, not just `env`.
    env.close()
    wrapped.close()
|