Files
Gymnasium/tests/wrappers/test_array_conversion.py

252 lines
9.3 KiB
Python
Raw Permalink Normal View History

"""Test suite for ArrayConversion wrapper."""
import importlib
import itertools
import pickle
from typing import Any, NamedTuple
import pytest
import gymnasium
array_api_compat = pytest.importorskip("array_api_compat")
array_api_extra = pytest.importorskip("array_api_extra")
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
from gymnasium.wrappers import ArrayConversion # noqa: E402
from gymnasium.wrappers.array_conversion import ( # noqa: E402
array_conversion,
module_namespace,
)
from tests.testing_env import GenericTestEnv # noqa: E402
# Candidate Array API compatible modules; any that cannot be imported are skipped.
array_api_modules = [
    "numpy",
    "jax.numpy",
    "torch",
    "cupy",
    "dask.array",
    "sparse",
    "array_api_strict",
]


def _import_if_installed(name):
    """Return the imported module ``name``, or ``None`` if importing it raises ImportError."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# Successfully imported modules, in the same order as ``array_api_modules``.
installed_modules = [
    mod for mod in map(_import_if_installed, array_api_modules) if mod is not None
]
# Every ordered (source, target) pair of two distinct installed modules.
installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
def xp_data_equivalence(data_1, data_2) -> bool:
    """Return whether two (possibly nested) values are equivalent.

    Values may be dicts, tuples (including named tuples), lists, Array API compatible arrays
    from any framework, or plain Python objects. Both values must have exactly the same type
    at every level of nesting. Arrays are equivalent when they have the same shape and are
    element-wise close (``atol=1e-5``); dtypes are not compared. Any other values are
    compared with ``==``.
    """
    if type(data_1) is not type(data_2):
        return False
    if isinstance(data_1, dict):
        return data_1.keys() == data_2.keys() and all(
            xp_data_equivalence(data_1[k], data_2[k]) for k in data_1
        )
    if isinstance(data_1, (tuple, list)):
        return len(data_1) == len(data_2) and all(
            xp_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
        )
    if is_array_api_obj(data_1):
        # ``isclose`` broadcasts its inputs, so without this check a scalar would compare
        # equal to a vector filled with the same value.
        if tuple(data_1.shape) != tuple(data_2.shape):
            return False
        xp = array_namespace(data_1)
        # ``xp.all`` returns a 0-d array; convert it so callers get a real ``bool``.
        return bool(xp.all(array_api_extra.isclose(data_1, data_2, atol=0.00001)))
    return data_1 == data_2
class ExampleNamedTuple(NamedTuple):
    """Two-field named tuple used as a container in the roundtrip parametrization.

    Both fields hold Array API compatible objects. They are annotated as ``Any`` because such
    objects do not have proper typing support yet.
    """

    a: Any
    b: Any
def _supports_higher_precision(xp, low_type, high_type):
    """Return whether ``xp`` promotes the pair (``low_type``, ``high_type``) to ``high_type``."""
    return xp.result_type(low_type, high_type) == high_type


# When converting between array modules (source -> target -> source), the expected result can
# only be as precise as the less precise of the two modules: if either module only supports
# 32-bit types, the expected value must use the lower precision to account for the conversion
# during the roundtrip.
def _roundtrip_dtype(source_xp, target_xp, low_name, high_name):
    """Return the dtype ``source_xp.<high_name>`` if both modules support it.

    A module "supports" the higher precision when ``_supports_higher_precision`` holds for
    its own ``low_name``/``high_name`` dtypes. Otherwise ``source_xp.<low_name>`` is returned.
    The returned dtype always belongs to ``source_xp``.
    """
    both_support_high = all(
        _supports_higher_precision(xp, getattr(xp, low_name), getattr(xp, high_name))
        for xp in (source_xp, target_xp)
    )
    return getattr(source_xp, high_name if both_support_high else low_name)


def atleast_float32(source_xp, target_xp):
    """Return source_xp.float64 if both modules support it, otherwise source_xp.float32."""
    return _roundtrip_dtype(source_xp, target_xp, "float32", "float64")


def atleast_int32(source_xp, target_xp):
    """Return source_xp.int64 if both modules support it, otherwise source_xp.int32."""
    return _roundtrip_dtype(source_xp, target_xp, "int32", "int64")
def value_parametrization():
    """Yield ``(source_xp, target_xp, value, expected_value)`` roundtrip test cases.

    For every ordered pair of distinct installed array modules (both normalised with
    ``module_namespace``), yields values of several kinds -- Python scalars, tuples, lists,
    dicts, nested containers, ``source_xp`` arrays, a named tuple and ``None`` -- together
    with the value expected after converting ``source_xp -> target_xp -> source_xp``.
    """
    for source_module, target_module in installed_modules_combinations:
        source_xp = module_namespace(source_module)
        target_xp = module_namespace(target_module)
        # Expected values are built in the source namespace.
        xp = source_xp
        # Python scalars come back with the widest precision both modules support. These dtypes
        # only depend on the module pair, so compute them once per pair.
        float_dtype = atleast_float32(source_xp, target_xp)
        int_dtype = atleast_int32(source_xp, target_xp)
        for value, expected_value in [
            (2, xp.asarray(2, dtype=int_dtype)),
            (
                (3.0, 4),
                (
                    xp.asarray(3.0, dtype=float_dtype),
                    xp.asarray(4, dtype=int_dtype),
                ),
            ),
            (
                [3.0, 4],
                [
                    xp.asarray(3.0, dtype=float_dtype),
                    xp.asarray(4, dtype=int_dtype),
                ],
            ),
            (
                {
                    "a": 6.0,
                    "b": 7,
                },
                {
                    "a": xp.asarray(6.0, dtype=float_dtype),
                    "b": xp.asarray(7, dtype=int_dtype),
                },
            ),
            (xp.asarray(1.0, dtype=xp.float32), xp.asarray(1.0, dtype=xp.float32)),
            (xp.asarray(1.0, dtype=xp.uint8), xp.asarray(1.0, dtype=xp.uint8)),
            (xp.asarray([1, 2], dtype=xp.int32), xp.asarray([1, 2], dtype=xp.int32)),
            (
                xp.asarray([[1.0], [2.0]], dtype=xp.int32),
                xp.asarray([[1.0], [2.0]], dtype=xp.int32),
            ),
            (
                {
                    "a": (
                        1,
                        xp.asarray(2.0, dtype=xp.float32),
                        xp.asarray([3, 4], dtype=xp.int32),
                    ),
                    "b": {"c": 5},
                },
                {
                    "a": (
                        xp.asarray(1, dtype=int_dtype),
                        xp.asarray(2.0, dtype=xp.float32),
                        xp.asarray([3, 4], dtype=xp.int32),
                    ),
                    "b": {"c": xp.asarray(5, dtype=int_dtype)},
                },
            ),
            (
                ExampleNamedTuple(
                    a=xp.asarray([1, 2], dtype=xp.int32),
                    b=xp.asarray([1.0, 2.0], dtype=xp.float32),
                ),
                ExampleNamedTuple(
                    a=xp.asarray([1, 2], dtype=xp.int32),
                    b=xp.asarray([1.0, 2.0], dtype=xp.float32),
                ),
            ),
            (None, None),
        ]:
            yield (source_xp, target_xp, value, expected_value)
@pytest.mark.parametrize(
    "source_xp,target_xp,value,expected_value", value_parametrization()
)
def test_roundtripping(source_xp, target_xp, value, expected_value):
    """Check that ``source_xp -> target_xp -> source_xp`` conversion yields ``expected_value``."""
    in_target = array_conversion(value, xp=target_xp)
    back_in_source = array_conversion(in_target, xp=source_xp)
    assert xp_data_equivalence(back_in_source, expected_value)
@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
def test_array_conversion_wrapper(env_xp, target_xp):
    """Test that ``ArrayConversion`` exposes ``target_xp`` data for an ``env_xp`` environment.

    Checks that:
    - the unwrapped env produces ``env_xp`` arrays;
    - the wrapped env produces ``target_xp`` observations and infos, with the reward as a
      Python ``float`` and terminated/truncated as Python ``bool``;
    - the wrapped env can render;
    - a wrapped env can be pickled and unpickled.
    """

    # Local functions cannot be pickled by the standard ``pickle`` module, so this
    # GenericTestEnv is not used for the pickling check at the end of the test; a registered
    # CartPole environment is used there instead.
    def reset_func(self, seed=None, options=None):
        """A generic array API reset function."""
        return env_xp.asarray([1.0, 2.0, 3.0]), {"data": env_xp.asarray([1, 2, 3])}

    def step_func(self, action):
        """A generic array API step function."""
        assert isinstance(action, type(env_xp.zeros(1)))
        return (
            env_xp.asarray([1, 2, 3]),
            env_xp.asarray(5.0),
            env_xp.asarray(True),
            env_xp.asarray(False),
            {"data": env_xp.asarray([1.0, 2.0])},
        )

    env = GenericTestEnv(reset_func=reset_func, step_func=step_func)

    # Check that the reset and step for env_xp environment are as expected
    obs, info = env.reset()
    # env_xp is automatically converted to the compatible namespace by array_namespace, so we need
    # to check against the compatible namespace of env_xp in array_api_compat
    env_xp_compat = module_namespace(env_xp)
    assert array_namespace(obs) is env_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat

    obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
    assert array_namespace(obs) is env_xp_compat
    assert array_namespace(reward) is env_xp_compat
    assert array_namespace(terminated) is env_xp_compat
    assert array_namespace(truncated) is env_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat

    # Check that the wrapped version is correct.
    target_xp_compat = module_namespace(target_xp)
    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
    obs, info = wrapped_env.reset()
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    action = target_xp.asarray([1, 2], dtype=target_xp.int32)
    obs, reward, terminated, truncated, info = wrapped_env.step(action)
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(reward, float)
    assert isinstance(terminated, bool) and isinstance(truncated, bool)
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    # Check that the wrapped environment can render. This implicitly returns None and requires a
    # None -> None conversion
    wrapped_env.render()
    wrapped_env.close()

    # Test that the wrapped environment can be pickled and restored
    env = gymnasium.make("CartPole-v1", disable_env_checker=True)
    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
    pkl = pickle.dumps(wrapped_env)
    restored_env = pickle.loads(pkl)
    assert isinstance(restored_env, ArrayConversion)
    restored_env.close()
    wrapped_env.close()