Files
Gymnasium/tests/wrappers/vector/test_array_conversion.py

147 lines
5.8 KiB
Python
Raw Normal View History

"""Test suite for vector ArrayConversion wrapper."""
import importlib
import itertools
from functools import partial
import numpy as np
import pytest
from tests.testing_env import GenericTestVectorEnv
array_api_compat = pytest.importorskip("array_api_compat")
from array_api_compat import array_namespace # noqa: E402
from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402
from gymnasium.wrappers.vector.array_conversion import ArrayConversion # noqa: E402
from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy # noqa: E402
from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch # noqa: E402
from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch # noqa: E402
# Array API frameworks the conversion wrappers are exercised against. Only those that can
# actually be imported in the current environment end up in ``installed_modules``.
array_api_modules = [
    "numpy",
    "jax.numpy",
    "torch",
    "cupy",
    "dask.array",
    "sparse",
    "array_api_strict",
]
installed_modules = []
for module in array_api_modules:
    try:
        imported = importlib.import_module(module)
    except ImportError:
        continue  # Frameworks that are not installed are left out
    installed_modules.append(imported)
# Every ordered pair of two distinct installed frameworks, used as (env_xp, target_xp)
installed_modules_combinations = [*itertools.permutations(installed_modules, 2)]
def create_vector_env(env_xp):
    """Build a ``GenericTestVectorEnv`` whose reset and step functions use ``env_xp``.

    Args:
        env_xp: Array namespace forwarded to ``reset_func`` and ``step_func`` as ``xp``.

    Returns:
        A ``GenericTestVectorEnv`` with both functions pre-bound to ``num_envs=3``.
    """
    # Both functions share the same pre-bound keyword arguments
    bound_kwargs = {"num_envs": 3, "xp": env_xp}
    return GenericTestVectorEnv(
        reset_func=partial(reset_func, **bound_kwargs),
        step_func=partial(step_func, **bound_kwargs),
    )
def reset_func(self, seed=None, options=None, num_envs: int = 1, xp=np):
    """Reset stub for the test vector env that returns arrays created with ``xp``.

    Args:
        self: First positional argument, unused here.
        seed: Accepted for signature compatibility, unused here.
        options: Accepted for signature compatibility, unused here.
        num_envs: Number of rows in the returned batch.
        xp: Array namespace used to create the returned arrays.

    Returns:
        A tuple ``(obs, info)``. ``obs`` is a float array of shape ``(num_envs, 3)`` and
        ``info["data"]`` is an integer array of the same shape.
    """
    # Repeat the *row*, not its elements: ``[[...] * n]`` would concatenate everything into
    # a single row of length ``3 * n`` instead of stacking ``n`` rows.
    obs = xp.asarray([[1.0, 2.0, 3.0]] * num_envs)
    info = {"data": xp.asarray([[1, 2, 3]] * num_envs)}
    return obs, info
def step_func(self, action, num_envs: int = 1, xp=np):
    """Step stub for the test vector env that returns arrays created with ``xp``.

    Args:
        self: First positional argument, unused here.
        action: Action array; must be of the same array type that ``xp.zeros`` produces.
        num_envs: Number of entries along the leading axis of every returned array.
        xp: Array namespace used to create the returned arrays.

    Returns:
        A tuple ``(obs, reward, terminated, truncated, info)``. ``obs`` has shape
        ``(num_envs, 3)``; ``reward``, ``terminated`` and ``truncated`` have shape
        ``(num_envs,)``; ``info["data"]`` has shape ``(num_envs, 2)``.

    Raises:
        AssertionError: If ``action`` is not an array of the ``xp`` framework.
    """
    # The action has to arrive in the environment's own framework
    assert isinstance(action, type(xp.zeros(1)))
    # Repeat whole rows so the leading axis matches reward/terminated/truncated.
    # ``[[...] * n]`` would instead produce a single row of concatenated values.
    obs = xp.asarray([[1, 2, 3]] * num_envs)
    reward = xp.asarray([5.0] * num_envs)
    terminated = xp.asarray([False] * num_envs)
    truncated = xp.asarray([False] * num_envs)
    info = {"data": xp.asarray([[1.0, 2.0]] * num_envs)}
    return obs, reward, terminated, truncated, info
@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
def test_array_conversion_wrapper(env_xp, target_xp):
    """Check that ``ArrayConversion`` converts all env outputs from ``env_xp`` to ``target_xp``.

    The unwrapped environment is checked first to make sure it really produces ``env_xp``
    arrays; afterwards the wrapped environment must return ``target_xp`` arrays from
    ``reset`` and ``step``, and ``render`` must still be callable.
    """
    env_xp_compat = module_namespace(env_xp)
    env = create_vector_env(env_xp_compat)

    # Check that the reset and step for env_xp environment are as expected
    obs, info = env.reset()
    # env_xp is automatically converted to the compatible namespace by array_namespace, so we need
    # to check against the compatible namespace of env_xp in array_api_compat
    assert array_namespace(obs) is env_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat

    obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
    assert array_namespace(obs) is env_xp_compat
    assert array_namespace(reward) is env_xp_compat
    assert array_namespace(terminated) is env_xp_compat
    assert array_namespace(truncated) is env_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat

    # Check that the wrapped version is correct.
    target_xp_compat = module_namespace(target_xp)
    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
    obs, info = wrapped_env.reset()
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    # Array creation and dtypes are looked up on the compat namespace rather than on the raw
    # module: the raw module is not guaranteed to expose the Array API names (for example,
    # ``np.bool`` does not exist on NumPy 1.24 - 1.26).
    action = target_xp_compat.asarray([1, 2], dtype=target_xp_compat.int32)
    obs, reward, terminated, truncated, info = wrapped_env.step(action)
    assert array_namespace(obs) is target_xp_compat
    assert array_namespace(reward) is target_xp_compat
    assert array_namespace(terminated) is target_xp_compat
    assert terminated.dtype == target_xp_compat.bool
    assert array_namespace(truncated) is target_xp_compat
    assert truncated.dtype == target_xp_compat.bool
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    # Check that the wrapped environment can render. This implicitly returns None and requires a
    # None -> None conversion
    wrapped_env.render()
@pytest.mark.parametrize("wrapper", [JaxToNumpy, JaxToTorch, NumpyToTorch])
def test_specialized_wrappers(wrapper: type[JaxToNumpy | JaxToTorch | NumpyToTorch]):
    """Check that each specialized conversion wrapper returns arrays of its target framework.

    The test is skipped when a framework required by the given wrapper is not installed.

    Raises:
        TypeError: If ``wrapper`` is not one of the known specialized wrappers.
    """
    if wrapper is JaxToNumpy:
        jax = pytest.importorskip("jax")
        env_xp, target_xp = jax.numpy, np
    elif wrapper is JaxToTorch:
        jax = pytest.importorskip("jax")
        torch = pytest.importorskip("torch")
        env_xp, target_xp = jax.numpy, torch
    elif wrapper is NumpyToTorch:
        torch = pytest.importorskip("torch")
        env_xp, target_xp = np, torch
    else:
        # ``wrapper`` is itself a class, so ``type(wrapper)`` would only ever report its
        # metaclass; format the wrapper directly to show which one was not recognised.
        raise TypeError(f"Unknown specialized conversion wrapper {wrapper}")
    env_xp_compat = module_namespace(env_xp)
    target_xp_compat = module_namespace(target_xp)

    # The unwrapped test env sanity check is already covered by test_array_conversion_wrapper for
    # all known frameworks, including the specialized ones.
    env = create_vector_env(env_xp_compat)

    # Check that the wrapped version is correct.
    wrapped_env = wrapper(env)
    obs, info = wrapped_env.reset()
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    # Array creation and dtypes are looked up on the compat namespace rather than on the raw
    # module: the raw module is not guaranteed to expose the Array API names (for example,
    # ``np.bool`` does not exist on NumPy 1.24 - 1.26).
    action = target_xp_compat.asarray([1, 2], dtype=target_xp_compat.int32)
    obs, reward, terminated, truncated, info = wrapped_env.step(action)
    assert array_namespace(obs) is target_xp_compat
    assert array_namespace(reward) is target_xp_compat
    assert array_namespace(terminated) is target_xp_compat
    assert terminated.dtype == target_xp_compat.bool
    assert array_namespace(truncated) is target_xp_compat
    assert truncated.dtype == target_xp_compat.bool
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    # Check that the wrapped environment can render. This implicitly returns None and requires a
    # None -> None conversion
    wrapped_env.render()