Files
Gymnasium/tests/wrappers/vector/test_array_conversion.py

147 lines
5.8 KiB
Python

"""Test suite for vector ArrayConversion wrapper."""
import importlib
import itertools
from functools import partial
import numpy as np
import pytest
from tests.testing_env import GenericTestVectorEnv
array_api_compat = pytest.importorskip("array_api_compat")
from array_api_compat import array_namespace # noqa: E402
from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402
from gymnasium.wrappers.vector.array_conversion import ArrayConversion # noqa: E402
from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy # noqa: E402
from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch # noqa: E402
from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch # noqa: E402
# Array API frameworks the conversion wrapper is exercised against. Frameworks that cannot be
# imported in the current environment are left out.
array_api_modules = [
    "numpy",
    "jax.numpy",
    "torch",
    "cupy",
    "dask.array",
    "sparse",
    "array_api_strict",
]
installed_modules = []
for module in array_api_modules:
    try:
        imported_module = importlib.import_module(module)
    except ImportError:
        continue  # Not installed here, so there is nothing to test for this framework
    installed_modules.append(imported_module)
# Every ordered pair of two distinct installed frameworks, consumed as (env_xp, target_xp).
installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
def create_vector_env(env_xp):
    """Build a ``GenericTestVectorEnv`` whose reset/step stubs emit ``env_xp`` arrays.

    Both stubs are bound with ``num_envs=3`` and ``xp=env_xp``.
    """
    bound_kwargs = {"num_envs": 3, "xp": env_xp}
    return GenericTestVectorEnv(
        reset_func=partial(reset_func, **bound_kwargs),
        step_func=partial(step_func, **bound_kwargs),
    )
def reset_func(self, seed=None, options=None, num_envs: int = 1, xp=np):
    """Reset stub for the test vector env, producing arrays of the ``xp`` framework.

    ``self``, ``seed`` and ``options`` are accepted only to match the reset signature and are
    not used.

    Returns:
        An ``(observation, info)`` pair. The observation is built from float literals and has
        shape ``(1, 3 * num_envs)``. ``info["data"]`` is built from int literals with the same
        shape.
    """
    observation = xp.asarray([[1.0, 2.0, 3.0] * num_envs])
    info = {"data": xp.asarray([[1, 2, 3] * num_envs])}
    return observation, info
def step_func(self, action, num_envs: int = 1, xp=np):
    """Step stub for the test vector env, producing arrays of the ``xp`` framework.

    The action must be an instance of ``xp``'s array type (checked with ``assert``). Its values
    are otherwise ignored.

    Returns:
        ``(observation, reward, terminated, truncated, info)``:
        - the observation has shape ``(1, 3 * num_envs)``;
        - reward, terminated and truncated each have shape ``(num_envs,)``, with every reward
          ``5.0`` and every flag ``False``;
        - ``info["data"]`` has shape ``(1, 2 * num_envs)``.
    """
    assert isinstance(action, type(xp.zeros(1)))
    observation = xp.asarray([[1, 2, 3] * num_envs])
    reward = xp.asarray([5.0] * num_envs)
    terminated = xp.asarray([False] * num_envs)
    truncated = xp.asarray([False] * num_envs)
    info = {"data": xp.asarray([[1.0, 2.0] * num_envs])}
    return observation, reward, terminated, truncated, info
def _check_wrapped_env(wrapped_env, target_xp_compat):
    """Assert that a conversion-wrapped env exchanges arrays of ``target_xp_compat`` only.

    Runs ``reset``, ``step`` and ``render`` on ``wrapped_env``. Dtype constants are read from the
    Array API-compatible namespace rather than the raw framework module. Raw modules do not expose
    them uniformly: for example, ``numpy.bool`` does not exist on NumPy 1.24-1.26.

    Args:
        wrapped_env: Environment wrapped by a conversion wrapper.
        target_xp_compat: Array API-compatible namespace the wrapper converts to, as returned by
            ``module_namespace``.
    """
    obs, info = wrapped_env.reset()
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    action = target_xp_compat.asarray([1, 2], dtype=target_xp_compat.int32)
    obs, reward, terminated, truncated, info = wrapped_env.step(action)
    assert array_namespace(obs) is target_xp_compat
    assert array_namespace(reward) is target_xp_compat
    assert array_namespace(terminated) is target_xp_compat
    assert terminated.dtype == target_xp_compat.bool
    assert array_namespace(truncated) is target_xp_compat
    assert truncated.dtype == target_xp_compat.bool
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    # Check that the wrapped environment can render. This implicitly returns None and requires a
    # None -> None conversion
    wrapped_env.render()


@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
def test_array_conversion_wrapper(env_xp, target_xp):
    """Check ``ArrayConversion`` between an ordered pair of installed Array API frameworks."""
    env_xp_compat = module_namespace(env_xp)
    env = create_vector_env(env_xp_compat)

    # Check that the reset and step for env_xp environment are as expected
    obs, info = env.reset()
    # env_xp is automatically converted to the compatible namespace by array_namespace, so we need
    # to check against the compatible namespace of env_xp in array_api_compat
    assert array_namespace(obs) is env_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
    obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
    assert array_namespace(obs) is env_xp_compat
    assert array_namespace(reward) is env_xp_compat
    assert array_namespace(terminated) is env_xp_compat
    assert array_namespace(truncated) is env_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat

    # Check that the wrapped version is correct.
    target_xp_compat = module_namespace(target_xp)
    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
    _check_wrapped_env(wrapped_env, target_xp_compat)


@pytest.mark.parametrize("wrapper", [JaxToNumpy, JaxToTorch, NumpyToTorch])
# The annotation is quoted so the ``A | B`` union of classes is not evaluated when the function is
# defined, which would raise TypeError on Python < 3.10.
def test_specialized_wrappers(wrapper: "type[JaxToNumpy | JaxToTorch | NumpyToTorch]"):
    """Check the fixed-pair conversion wrappers, skipping if a required framework is missing."""
    if wrapper is JaxToNumpy:
        jax = pytest.importorskip("jax")
        env_xp, target_xp = jax.numpy, np
    elif wrapper is JaxToTorch:
        jax = pytest.importorskip("jax")
        torch = pytest.importorskip("torch")
        env_xp, target_xp = jax.numpy, torch
    elif wrapper is NumpyToTorch:
        torch = pytest.importorskip("torch")
        env_xp, target_xp = np, torch
    else:
        # Report the wrapper itself; ``type(wrapper)`` would only name its metaclass.
        raise TypeError(f"Unknown specialized conversion wrapper {wrapper}")

    env_xp_compat = module_namespace(env_xp)
    target_xp_compat = module_namespace(target_xp)

    # The unwrapped test env sanity check is already covered by test_array_conversion_wrapper for
    # all known frameworks, including the specialized ones.
    env = create_vector_env(env_xp_compat)

    # Check that the wrapped version is correct.
    _check_wrapped_env(wrapper(env), target_xp_compat)