2019-10-25 23:20:53 +02:00
|
|
|
import numpy as np
|
2022-03-31 12:50:38 -07:00
|
|
|
import pytest
|
2019-10-25 23:20:53 +02:00
|
|
|
|
|
|
|
import gym
|
|
|
|
from gym.wrappers import RescaleAction
|
|
|
|
|
|
|
|
|
|
|
|
def test_rescale_action():
    """Test that ``RescaleAction`` linearly rescales actions into the env's bounds.

    Checks three things:

    * wrapping ``CartPole-v1`` with ``RescaleAction`` raises ``AssertionError``;
    * with the same seed, the wrapped and unwrapped ``Pendulum-v1`` envs
      produce the same initial observation;
    * an action outside the wrapper's ``[-1, 1]`` range is rejected, while the
      in-range action ``0.75`` on the wrapped env yields the same observation
      and reward as the raw action ``1.5`` on the unwrapped env.
    """
    # RescaleAction is expected to refuse CartPole-v1 (presumably because its
    # action space is not a continuous Box). The assignment never completes,
    # so ``env`` still refers to the unwrapped CartPole env and must be closed.
    env = gym.make("CartPole-v1", disable_env_checker=True)
    with pytest.raises(AssertionError):
        env = RescaleAction(env, -1, 1)
    env.close()
    del env

    env = gym.make("Pendulum-v1", disable_env_checker=True)
    wrapped_env = RescaleAction(
        gym.make("Pendulum-v1", disable_env_checker=True), -1, 1
    )
    try:
        # Seed both envs identically so their trajectories are comparable.
        seed = 0
        obs, _ = env.reset(seed=seed)
        wrapped_obs, _ = wrapped_env.reset(seed=seed)
        assert np.allclose(obs, wrapped_obs)

        # A raw action of 1.5 on the unwrapped env should correspond to 0.75
        # in the wrapper's [-1, 1] range. This holds if Pendulum's action
        # bounds are [-2, 2] (assumption, not visible here):
        # -2 + (2 - -2) * (0.75 - -1) / (1 - -1) == 1.5.
        obs, reward, _, _, _ = env.step([1.5])

        # 1.5 lies outside the wrapper's [-1, 1] range, so it must be rejected.
        with pytest.raises(AssertionError):
            wrapped_env.step([1.5])

        wrapped_obs, wrapped_reward, _, _, _ = wrapped_env.step([0.75])
        assert np.allclose(obs, wrapped_obs)
        assert np.allclose(reward, wrapped_reward)
    finally:
        # Release both environments even if an assertion above fails.
        env.close()
        wrapped_env.close()
|