mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
251 lines
9.2 KiB
Python
251 lines
9.2 KiB
Python
"""Test suite for ArrayConversion wrapper."""
|
|
|
|
import importlib
|
|
import itertools
|
|
import pickle
|
|
from typing import Any, NamedTuple
|
|
|
|
import pytest
|
|
|
|
import gymnasium
|
|
|
|
|
|
array_api_compat = pytest.importorskip("array_api_compat")
|
|
array_api_extra = pytest.importorskip("array_api_extra")
|
|
|
|
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
|
|
|
|
from gymnasium.wrappers import ArrayConversion # noqa: E402
|
|
from gymnasium.wrappers.array_conversion import ( # noqa: E402
|
|
array_conversion,
|
|
module_namespace,
|
|
)
|
|
from tests.testing_env import GenericTestEnv # noqa: E402
|
|
|
|
|
|
# Candidate Array API backends, probed in this order. Every backend whose import
# succeeds is kept; backends that are not installed are skipped silently.
array_api_modules = [
    "numpy",
    "jax.numpy",
    "torch",
    "cupy",
    "dask.array",
    "sparse",
    "array_api_strict",
]


def _import_optional(name):
    """Import and return the module ``name``, or ``None`` if it cannot be imported."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None  # Modules that are not installed are skipped


installed_modules = [
    imported
    for imported in map(_import_optional, array_api_modules)
    if imported is not None
]

# Every ordered (source, target) pair of two distinct installed backends.
installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
|
|
|
|
|
|
def xp_data_equivalence(data_1, data_2) -> bool:
    """Return whether two (possibly nested) values are equivalent.

    Values may be dicts, tuples/lists (including NamedTuples), Array API compatible
    arrays from any supported framework, or plain Python objects. At every nesting
    level both values must have exactly the same type. Arrays must additionally have
    the same shape and be element-wise close (``atol=1e-5``); dtypes are not compared.
    Any other values are compared with ``==``.
    """
    if type(data_1) is not type(data_2):
        return False

    if isinstance(data_1, dict):
        return data_1.keys() == data_2.keys() and all(
            xp_data_equivalence(data_1[k], data_2[k]) for k in data_1
        )

    if isinstance(data_1, (tuple, list)):
        return len(data_1) == len(data_2) and all(
            xp_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
        )

    if is_array_api_obj(data_1):
        # ``isclose`` broadcasts its operands, so e.g. shapes (1,) and (3,) would
        # otherwise be reported as equivalent; compare shapes explicitly first.
        if tuple(data_1.shape) != tuple(data_2.shape):
            return False
        xp = array_namespace(data_1)
        # Array methods such as ``.all()`` are not part of the Array API standard, so
        # strict implementations need not provide them; reduce via the namespace and
        # return a real ``bool`` as annotated.
        return bool(xp.all(array_api_extra.isclose(data_1, data_2, atol=0.00001)))

    return data_1 == data_2
|
|
|
|
|
|
class ExampleNamedTuple(NamedTuple):
    """Two-field NamedTuple used as a nested container in the roundtrip cases."""

    # Both fields hold Array API compatible objects, which do not have proper
    # typing support yet, hence ``Any``.
    a: Any
    b: Any
|
|
|
|
|
|
def _supports_higher_precision(xp, low_type, high_type):
|
|
"""Check if an array module supports higher precision type."""
|
|
return xp.result_type(low_type, high_type) == high_type
|
|
|
|
|
|
# During a roundtrip (source -> target -> source) a value can only keep a precision
# that *both* array modules support. If either module supports only 32-bit types, the
# roundtrip narrows the value, so the expected dtype has to be the 32-bit one.
def atleast_float32(source_xp, target_xp):
    """Return source_xp.float64 if both modules support it, otherwise source_xp.float32."""
    # Source is checked first, then target; both checks always run.
    support_flags = [
        _supports_higher_precision(ns, ns.float32, ns.float64)
        for ns in (source_xp, target_xp)
    ]
    return source_xp.float64 if all(support_flags) else source_xp.float32
|
|
|
|
|
|
def atleast_int32(source_xp, target_xp):
    """Return source_xp.int64 if both modules support it, otherwise source_xp.int32."""
    # Source is checked first, then target; both checks always run.
    support_flags = [
        _supports_higher_precision(ns, ns.int32, ns.int64)
        for ns in (source_xp, target_xp)
    ]
    if all(support_flags):
        return source_xp.int64
    return source_xp.int32
|
|
|
|
|
|
def value_parametrization():
    """Yield ``(source_xp, target_xp, value, expected_value)`` roundtrip cases.

    For every ordered pair of distinct installed array modules, a fixed set of values
    (Python scalars, tuples, lists, dicts, arrays, a NamedTuple and ``None``) is paired
    with the value expected after converting ``source -> target -> source``. Python
    scalars are expected back as ``source`` arrays whose dtype is the widest precision
    that both modules support (see ``atleast_int32`` / ``atleast_float32``).
    """
    for raw_source, raw_target in installed_modules_combinations:
        src = module_namespace(raw_source)
        dst = module_namespace(raw_target)
        int_dtype = atleast_int32(src, dst)
        float_dtype = atleast_float32(src, dst)

        cases = [
            (2, src.asarray(2, dtype=int_dtype)),
            (
                (3.0, 4),
                (
                    src.asarray(3.0, dtype=float_dtype),
                    src.asarray(4, dtype=int_dtype),
                ),
            ),
            (
                [3.0, 4],
                [
                    src.asarray(3.0, dtype=float_dtype),
                    src.asarray(4, dtype=int_dtype),
                ],
            ),
            (
                {"a": 6.0, "b": 7},
                {
                    "a": src.asarray(6.0, dtype=float_dtype),
                    "b": src.asarray(7, dtype=int_dtype),
                },
            ),
            (
                src.asarray(1.0, dtype=src.float32),
                src.asarray(1.0, dtype=src.float32),
            ),
            (
                src.asarray(1.0, dtype=src.uint8),
                src.asarray(1.0, dtype=src.uint8),
            ),
            (
                src.asarray([1, 2], dtype=src.int32),
                src.asarray([1, 2], dtype=src.int32),
            ),
            (
                src.asarray([[1.0], [2.0]], dtype=src.int32),
                src.asarray([[1.0], [2.0]], dtype=src.int32),
            ),
            (
                {
                    "a": (
                        1,
                        src.asarray(2.0, dtype=src.float32),
                        src.asarray([3, 4], dtype=src.int32),
                    ),
                    "b": {"c": 5},
                },
                {
                    "a": (
                        src.asarray(1, dtype=int_dtype),
                        src.asarray(2.0, dtype=src.float32),
                        src.asarray([3, 4], dtype=src.int32),
                    ),
                    "b": {"c": src.asarray(5, dtype=int_dtype)},
                },
            ),
            (
                ExampleNamedTuple(
                    a=src.asarray([1, 2], dtype=src.int32),
                    b=src.asarray([1.0, 2.0], dtype=src.float32),
                ),
                ExampleNamedTuple(
                    a=src.asarray([1, 2], dtype=src.int32),
                    b=src.asarray([1.0, 2.0], dtype=src.float32),
                ),
            ),
            (None, None),
        ]

        for value, expected_value in cases:
            yield src, dst, value, expected_value
|
|
|
|
|
|
@pytest.mark.parametrize(
    "source_xp,target_xp,value,expected_value", value_parametrization()
)
def test_roundtripping(source_xp, target_xp, value, expected_value):
    """Test roundtripping between different Array API compatible frameworks.

    ``value`` is converted to ``target_xp`` and then back to ``source_xp``; the result
    must be equivalent to ``expected_value``.
    """
    in_target = array_conversion(value, xp=target_xp)
    back_in_source = array_conversion(in_target, xp=source_xp)
    assert xp_data_equivalence(back_in_source, expected_value)
|
|
|
|
|
|
@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
def test_array_conversion_wrapper(env_xp, target_xp):
    """Check that ``ArrayConversion`` converts between ``env_xp`` and ``target_xp``.

    The unwrapped environment must emit ``env_xp`` arrays. The wrapped one must emit
    ``target_xp`` arrays for observations and infos, and Python ``float``/``bool``
    values for reward/terminated/truncated. Its ``render`` must pass ``None`` through,
    and a wrapped environment must survive a pickle roundtrip.
    """

    # reset/step are local closures over ``env_xp``. Local functions cannot be pickled,
    # so (presumably for that reason) the pickling check at the end wraps a registered
    # CartPole environment instead of this GenericTestEnv.
    def reset_func(self, seed=None, options=None):
        """A generic array API reset function."""
        return env_xp.asarray([1.0, 2.0, 3.0]), {"data": env_xp.asarray([1, 2, 3])}

    def step_func(self, action):
        """A generic array API step function."""
        assert isinstance(action, type(env_xp.zeros(1)))
        return (
            env_xp.asarray([1, 2, 3]),
            env_xp.asarray(5.0),
            env_xp.asarray(True),
            env_xp.asarray(False),
            {"data": env_xp.asarray([1.0, 2.0])},
        )

    env = GenericTestEnv(reset_func=reset_func, step_func=step_func)

    # Check that the reset and step for env_xp environment are as expected
    obs, info = env.reset()
    # env_xp is automatically converted to the compatible namespace by array_namespace, so we need
    # to check against the compatible namespace of env_xp in array_api_compat
    env_xp_compat = module_namespace(env_xp)
    assert array_namespace(obs) is env_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat

    obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
    assert array_namespace(obs) is env_xp_compat
    assert array_namespace(reward) is env_xp_compat
    assert array_namespace(terminated) is env_xp_compat
    assert array_namespace(truncated) is env_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat

    # Check that the wrapped version is correct.
    target_xp_compat = module_namespace(target_xp)
    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
    obs, info = wrapped_env.reset()
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    action = target_xp.asarray([1, 2], dtype=target_xp.int32)
    obs, reward, terminated, truncated, info = wrapped_env.step(action)
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(reward, float)
    assert isinstance(terminated, bool) and isinstance(truncated, bool)
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    # The base environment's render returns None, so the wrapper needs a
    # None -> None conversion; assert that it really comes back as None.
    assert wrapped_env.render() is None
    wrapped_env.close()

    # Test that the wrapped environment can be pickled and restored
    env = gymnasium.make("CartPole-v1", disable_env_checker=True)
    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
    pkl = pickle.dumps(wrapped_env)
    restored_env = pickle.loads(pkl)
    assert isinstance(restored_env, ArrayConversion)
    wrapped_env.close()
|