Files
Gymnasium/tests/wrappers/test_step_compatibility.py
Mark Towers 799c8d20c1 Rename "rgb_array" to "rgb_array_list" and "single_rgb_array" to "rgb_array" (#3040)
* initial commit

* Fix CI

* Fixed CI

* third time the charm

* Fix mujoco environment render modes order

* Remove unnecessary changes

* pre-commit

* Fix tests

* Comment out test render modes

* Fix code review and readd mujoco

* pre-commit

* Fix testing envs

* Fix all GenericTestEnvs

* Do not run mujoco-py render environments
2022-09-01 09:06:42 -04:00

83 lines
2.4 KiB
Python

import pytest
import gym
from gym.spaces import Discrete
from gym.wrappers import StepAPICompatibility
class OldStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)
def step(self, action):
obs = self.observation_space.sample()
rew = 0
done = False
info = {}
return obs, rew, done, info
class NewStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)
def step(self, action):
obs = self.observation_space.sample()
rew = 0
terminated = False
truncated = False
info = {}
return obs, rew, terminated, truncated, info
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
@pytest.mark.parametrize("output_truncation_bool", [None, True])
def test_step_compatibility_to_new_api(env, output_truncation_bool):
if output_truncation_bool is None:
env = StepAPICompatibility(env())
else:
env = StepAPICompatibility(env(), output_truncation_bool)
step_returns = env.step(0)
_, _, terminated, truncated, _ = step_returns
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
def test_step_compatibility_to_old_api(env):
env = StepAPICompatibility(env(), False)
step_returns = env.step(0)
assert len(step_returns) == 4
_, _, done, _ = step_returns
assert isinstance(done, bool)
@pytest.mark.parametrize("apply_step_compatibility", [None, True, False])
def test_step_compatibility_in_make(apply_step_compatibility):
gym.register("OldStepEnv-v0", entry_point=OldStepEnv)
if apply_step_compatibility is not None:
env = gym.make(
"OldStepEnv-v0",
apply_step_compatibility=apply_step_compatibility,
disable_env_checker=True,
)
elif apply_step_compatibility is None:
env = gym.make("OldStepEnv-v0", disable_env_checker=True)
env.reset()
step_returns = env.step(0)
if apply_step_compatibility:
assert len(step_returns) == 5
_, _, terminated, truncated, _ = step_returns
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)
else:
assert len(step_returns) == 4
_, _, done, _ = step_returns
assert isinstance(done, bool)
gym.envs.registry.pop("OldStepEnv-v0")