2022-06-19 21:50:07 +01:00
|
|
|
import pytest
|
|
|
|
|
2022-09-16 23:41:27 +01:00
|
|
|
import gymnasium as gym
|
2022-09-08 10:10:07 +01:00
|
|
|
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
2022-11-18 22:25:33 +01:00
|
|
|
from gymnasium.wrappers import TimeLimit, TransformObservation
|
2022-09-08 10:10:07 +01:00
|
|
|
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
2022-06-19 21:50:07 +01:00
|
|
|
from tests.wrappers.utils import has_wrapper
|
|
|
|
|
|
|
|
|
|
|
|
def test_vector_make_id():
    """Check that `gym.vector.make` given only an environment id builds one async sub-env.

    The test expects that, with no other arguments, the result is an
    ``AsyncVectorEnv`` holding a single copy of the environment.
    """
    env = gym.vector.make("CartPole-v1")
    try:
        assert isinstance(env, AsyncVectorEnv)
        assert env.num_envs == 1
    finally:
        # Close even when an assertion fails so the vector env (and any async
        # workers it started) does not outlive the test.
        env.close()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("num_envs", [1, 3, 10])
def test_vector_make_num_envs(num_envs):
    """Check that the ``num_envs`` argument sets the number of sub-environments.

    Args:
        num_envs: Requested number of environment copies (parametrized).
    """
    env = gym.vector.make("CartPole-v1", num_envs=num_envs)
    try:
        assert env.num_envs == num_envs
    finally:
        # Always release the vector env, even if the assertion above fails.
        env.close()
|
|
|
|
|
|
|
|
|
|
|
|
def test_vector_make_asynchronous():
    """Check that ``asynchronous`` selects the vector env implementation.

    ``asynchronous=True`` is expected to build an ``AsyncVectorEnv`` and
    ``asynchronous=False`` a ``SyncVectorEnv``.
    """
    cases = ((True, AsyncVectorEnv), (False, SyncVectorEnv))
    for asynchronous, expected_cls in cases:
        env = gym.vector.make("CartPole-v1", asynchronous=asynchronous)
        try:
            assert isinstance(env, expected_cls), (
                f"asynchronous={asynchronous} built {type(env).__name__}, "
                f"expected {expected_cls.__name__}"
            )
        finally:
            # Close each env before building the next, even on failure.
            env.close()
|
|
|
|
|
|
|
|
|
|
|
|
def test_vector_make_wrappers():
    """Check which wrappers `gym.vector.make` applies to each sub-environment.

    Without ``wrappers``:
        - every sub-env is a ``gym.Env`` with a spec;
        - it is wrapped in ``TimeLimit`` whenever that spec sets ``max_episode_steps``;
        - it carries no ``TransformObservation``.

    With ``wrappers``, the given callable is applied to every sub-env.
    """
    # Sub-environments of an asynchronous vector env are not directly accessible,
    # so a synchronous vector env is used to inspect the wrapper stacks.
    env = gym.vector.make("CartPole-v1", num_envs=2, asynchronous=False)
    try:
        assert isinstance(env, SyncVectorEnv)
        assert len(env.envs) == 2
        for sub_env in env.envs:
            assert isinstance(sub_env, gym.Env)
            assert sub_env.spec is not None
            if sub_env.spec.max_episode_steps is not None:
                assert has_wrapper(sub_env, TimeLimit)
            assert not has_wrapper(sub_env, TransformObservation)
    finally:
        env.close()

    env = gym.vector.make(
        "CartPole-v1",
        num_envs=2,
        asynchronous=False,
        wrappers=lambda _env: TransformObservation(_env, lambda obs: obs * 2),
    )
    try:
        assert isinstance(env, SyncVectorEnv)
        assert len(env.envs) == 2
        assert all(has_wrapper(sub_env, TransformObservation) for sub_env in env.envs)
    finally:
        env.close()
|
|
|
|
|
|
|
|
|
|
|
|
def test_vector_make_disable_env_checker():
    """Check how `gym.vector.make` applies ``PassiveEnvChecker`` to sub-environments.

    By default only the first sub-environment is expected to be wrapped in the
    checker. With ``disable_env_checker=True``, none of them are.
    """
    # Sub-environments of an asynchronous vector env are not directly accessible,
    # so a synchronous vector env is used to inspect the wrapper stacks.
    # Each case pairs the extra `make` kwargs with the expected per-sub-env
    # "has PassiveEnvChecker" flags. The list length also pins the sub-env count.
    cases = (
        ({"num_envs": 1}, [True]),
        ({"num_envs": 5}, [True, False, False, False, False]),
        ({"num_envs": 3, "disable_env_checker": True}, [False, False, False]),
    )
    for kwargs, expected in cases:
        env = gym.vector.make("CartPole-v1", asynchronous=False, **kwargs)
        try:
            assert isinstance(env, SyncVectorEnv)
            checked = [
                bool(has_wrapper(sub_env, PassiveEnvChecker)) for sub_env in env.envs
            ]
            assert checked == expected, f"{kwargs}: got {checked}, expected {expected}"
        finally:
            env.close()
|