2022-05-06 20:19:46 +05:30
|
|
|
import numpy as np
|
|
|
|
|
2022-03-31 12:50:38 -07:00
|
|
|
from gym.vector import VectorEnvWrapper, make
|
2020-08-14 14:20:56 -07:00
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2020-08-14 14:20:56 -07:00
|
|
|
class DummyWrapper(VectorEnvWrapper):
    """Minimal ``VectorEnvWrapper`` subclass that counts ``reset_async`` calls.

    Used by the tests below to check that base-class entry points such as
    ``reset()`` dispatch to a subclass's overridden ``reset_async``, and that
    attribute access (``call`` / ``set_attr`` / ``get_attr``) is forwarded to
    the wrapped vector env.
    """

    def __init__(self, env):
        """Wrap ``env`` and start the reset counter at zero.

        Args:
            env: the vector environment to wrap.
        """
        # Let the base wrapper perform its own initialisation/validation and
        # store ``self.env`` rather than bypassing it.
        super().__init__(env)
        # Number of times ``reset_async`` has been invoked on this wrapper.
        self.counter = 0

    def reset_async(self, **kwargs):
        """Forward ``reset_async`` to the base wrapper and count the call.

        Args:
            **kwargs: reset options (e.g. ``seed``); forwarded unchanged so that
                seeding/options passed to ``reset()`` are not silently dropped.
        """
        super().reset_async(**kwargs)
        self.counter += 1
|
|
|
|
|
|
|
|
|
|
|
|
def test_vector_env_wrapper_inheritance():
    """Check that ``reset()`` on a wrapper subclass reaches its overridden ``reset_async``."""
    env = make("FrozenLake-v1", asynchronous=False)
    wrapped = DummyWrapper(env)
    try:
        wrapped.reset()
        # DummyWrapper.reset_async increments the counter once per invocation.
        assert wrapped.counter == 1
    finally:
        # Release the vector env even if the assertion fails.
        env.close()
|
2022-05-06 20:19:46 +05:30
|
|
|
|
|
|
|
|
|
|
|
def test_vector_env_wrapper_attributes():
    """Test that `call`, `set_attr` and `get_attr` on a VectorEnvWrapper are forwarded to the vector env it wraps."""
    env = make("CartPole-v1", num_envs=3)
    inner = make("CartPole-v1", num_envs=3)
    wrapped = DummyWrapper(inner)
    try:
        assert np.allclose(wrapped.call("gravity"), env.call("gravity"))

        env.set_attr("gravity", [20.0, 20.0, 20.0])
        wrapped.set_attr("gravity", [20.0, 20.0, 20.0])

        assert np.allclose(wrapped.get_attr("gravity"), env.get_attr("gravity"))
        # Pin the absolute value too: the comparison above would also pass if
        # both set_attr calls were no-ops.
        assert np.allclose(wrapped.get_attr("gravity"), [20.0, 20.0, 20.0])
    finally:
        # Both envs may run worker processes; always shut them down.
        env.close()
        inner.close()
|