Files
Gymnasium/tests/wrappers/test_step_compatibility.py

81 lines
2.4 KiB
Python
Raw Normal View History

import pytest
import gym
from gym.spaces import Discrete
from gym.wrappers import StepAPICompatibility
class OldStepEnv(gym.Env):
    """Stub environment whose ``step`` returns the old-style 4-tuple.

    Both the action and observation spaces are ``Discrete(2)``. ``step``
    ignores the action it is given and returns ``(obs, reward, done, info)``.
    """

    def __init__(self):
        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)

    def step(self, action):
        """Return ``(obs, 0, False, {})`` with ``obs`` sampled from the observation space."""
        sampled_obs = self.observation_space.sample()
        return sampled_obs, 0, False, {}
class NewStepEnv(gym.Env):
    """Stub environment whose ``step`` returns the new-style 5-tuple.

    Both the action and observation spaces are ``Discrete(2)``. ``step``
    ignores the action it is given and returns
    ``(obs, reward, terminated, truncated, info)``.
    """

    def __init__(self):
        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)

    def step(self, action):
        """Return ``(obs, 0, False, False, {})`` with ``obs`` sampled from the observation space."""
        sampled_obs = self.observation_space.sample()
        return sampled_obs, 0, False, False, {}
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
@pytest.mark.parametrize("output_truncation_bool", [None, True])
def test_step_compatibility_to_new_api(env, output_truncation_bool):
    """Check that the wrapper yields a 5-element step result with bool flags.

    Args:
        env: Environment class to instantiate and wrap (old- or new-style
            ``step`` return).
        output_truncation_bool: ``None`` to construct the wrapper without the
            second argument, otherwise the value passed positionally.
    """
    # Omit the argument entirely for ``None`` so the wrapper's own default
    # is what gets exercised in that case.
    wrapper_args = () if output_truncation_bool is None else (output_truncation_bool,)
    env = StepAPICompatibility(env(), *wrapper_args)

    step_returns = env.step(0)
    # Assert the arity explicitly, as the sibling tests do, so a wrong-length
    # result fails as an assertion rather than as an unpacking ValueError.
    assert len(step_returns) == 5
    _, _, terminated, truncated, _ = step_returns
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
def test_step_compatibility_to_old_api(env):
    """Check that wrapping with ``False`` yields a 4-element step result with a bool ``done``.

    Args:
        env: Environment class to instantiate and wrap (old- or new-style
            ``step`` return).
    """
    wrapped = StepAPICompatibility(env(), False)
    result = wrapped.step(0)

    assert len(result) == 4
    _obs, _rew, done, _info = result
    assert isinstance(done, bool)
@pytest.mark.parametrize("apply_step_compatibility", [None, True, False])
def test_step_compatibility_in_make(apply_step_compatibility):
    """Check the step-result arity of an old-style env created through ``gym.make``.

    A truthy ``apply_step_compatibility`` is expected to give a 5-element
    result with bool ``terminated``/``truncated``; ``None`` (flag not passed)
    and ``False`` are expected to give a 4-element result with a bool ``done``.

    Args:
        apply_step_compatibility: Value forwarded to ``gym.make``, or ``None``
            to call ``gym.make`` without that keyword.
    """
    gym.register("OldStepEnv-v0", entry_point=OldStepEnv)

    # Forward the flag only when the parametrization supplies one, so the
    # ``None`` case calls ``gym.make`` without it.
    make_kwargs = {}
    if apply_step_compatibility is not None:
        make_kwargs["apply_step_compatibility"] = apply_step_compatibility
    make_kwargs["disable_env_checker"] = True
    env = gym.make("OldStepEnv-v0", **make_kwargs)

    env.reset()
    outcome = env.step(0)

    if not apply_step_compatibility:
        assert len(outcome) == 4
        _, _, done, _ = outcome
        assert isinstance(done, bool)
        return

    assert len(outcome) == 5
    _, _, terminated, truncated, _ = outcome
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)