mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
79 lines
2.5 KiB
Python
79 lines
2.5 KiB
Python
"""Test suite for lambda action wrappers: LambdaAction, ClipAction, RescaleAction."""
|
|
import numpy as np
|
|
|
|
from gymnasium.experimental.wrappers import (
|
|
ClipActionV0,
|
|
LambdaActionV0,
|
|
RescaleActionV0,
|
|
)
|
|
from gymnasium.spaces import Box
|
|
from tests.testing_env import GenericTestEnv
|
|
|
|
|
|
# Fixed random seed available to the tests in this module, so that any
# sampling seeded with it is reproducible from run to run.
SEED = 42
|
|
|
|
|
|
def _record_action_step_func(self, action):
|
|
return 0, 0, False, False, {"action": action}
|
|
|
|
|
|
def test_lambda_action_wrapper():
    """Tests LambdaAction through checking that the action taken is transformed by function.

    The wrapper exposes ``Box(2, 3)`` as its action space and maps each action
    through ``action - 2`` before forwarding it to the base environment.
    """
    env = GenericTestEnv(step_func=_record_action_step_func)
    wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3))
    # Seed the wrapper's action space so the sampled action, and therefore the
    # whole test, is reproducible across runs.
    wrapped_env.action_space.seed(SEED)

    sampled_action = wrapped_env.action_space.sample()
    # The raw sample from Box(2, 3) is expected to be invalid for the base env...
    assert sampled_action not in env.action_space

    _, _, _, _, info = wrapped_env.step(sampled_action)
    # ...while the lambda-transformed action that reached the base env is valid.
    assert info["action"] in env.action_space
    # ``np.array_equal`` yields a single bool for arrays of any size, unlike a
    # bare ``==`` whose truth value is ambiguous for multi-element arrays.
    assert np.array_equal(sampled_action - 2, info["action"])
|
|
|
|
|
|
def test_clip_action_wrapper():
    """Test that the action is correctly clipped to the base environment action space."""
    env = GenericTestEnv(
        action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])),
        step_func=_record_action_step_func,
    )
    wrapped_env = ClipActionV0(env)

    # Component 0 is below its lower bound (0), component 1 is above its upper
    # bound (2), and component 2 already lies in [3, 4]. Only the first two
    # components should be changed by clipping.
    sampled_action = np.array([-1, 5, 3.5], dtype=np.float32)
    assert sampled_action not in env.action_space
    assert sampled_action in wrapped_env.action_space

    _, _, _, _, info = wrapped_env.step(sampled_action)
    # Space membership already evaluates to a plain bool, so no ``np.all`` is needed.
    assert info["action"] in env.action_space
    # ``np.array_equal`` also checks shape, whereas ``np.all(a == b)`` would
    # silently broadcast mismatched shapes.
    assert np.array_equal(info["action"], np.array([0, 2, 3.5]))
|
|
|
|
|
|
def test_rescale_action_wrapper():
    """Test that the action is rescaled from a min / max bound to the base action space."""
    env = GenericTestEnv(
        step_func=_record_action_step_func,
        action_space=Box(np.array([0, 1]), np.array([1, 3])),
    )
    wrapped_env = RescaleActionV0(
        env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
    )
    assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))

    # Pairs of (action given to the wrapper, action expected at the base env).
    # The midpoint and the low / high corners of the wrapper's space should map
    # onto the midpoint and low / high corners of the base env's space.
    cases = (
        (
            np.array([0.0, 0.5], dtype=np.float32),
            np.array([0.5, 2.0], dtype=np.float32),
        ),
        (
            np.array([-5.0, 0.0], dtype=np.float32),
            np.array([0.0, 1.0], dtype=np.float32),
        ),
        (
            np.array([5.0, 1.0], dtype=np.float32),
            np.array([1.0, 3.0], dtype=np.float32),
        ),
    )
    for sample_action, expected_action in cases:
        assert sample_action in wrapped_env.action_space

        _, _, _, _, info = wrapped_env.step(sample_action)
        # Rescaling is float32 arithmetic, so compare with a small tolerance
        # instead of exact equality. ``atol`` is needed because some expected
        # components are exactly 0.0, where a relative tolerance alone is zero.
        np.testing.assert_allclose(
            info["action"], expected_action, rtol=1e-6, atol=1e-6
        )
|