mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-28 01:07:11 +00:00
81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
import pytest
|
|
|
|
import gym
|
|
from gym.spaces import Discrete
|
|
from gym.wrappers import StepAPICompatibility
|
|
|
|
|
|
class OldStepEnv(gym.Env):
    """Toy environment whose ``step`` follows the old 4-tuple API.

    ``step`` returns ``(observation, reward, done, info)``.
    """

    def __init__(self):
        # Both spaces are two-element discrete spaces ({0, 1}).
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)

    def step(self, action):
        """Ignore ``action``; return a random observation, reward 0, ``done=False`` and an empty info dict."""
        observation = self.observation_space.sample()
        return observation, 0, False, {}
|
|
|
|
|
|
class NewStepEnv(gym.Env):
    """Toy environment whose ``step`` follows the new 5-tuple API.

    ``step`` returns ``(observation, reward, terminated, truncated, info)``.
    """

    def __init__(self):
        # Both spaces are two-element discrete spaces ({0, 1}).
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)

    def step(self, action):
        """Ignore ``action``; return a random observation, reward 0, not terminated, not truncated, empty info."""
        observation = self.observation_space.sample()
        return observation, 0, False, False, {}
|
|
|
|
|
|
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
@pytest.mark.parametrize("output_truncation_bool", [None, True])
def test_step_compatibility_to_new_api(env, output_truncation_bool):
    """``StepAPICompatibility`` must expose the new 5-tuple step API.

    The check runs for both old-API and new-API inner environments. It uses the
    wrapper's default second argument (the ``None`` case) and an explicit ``True``.
    """
    # ``env`` is the parametrized environment *class*; keep the wrapper under a
    # separate name instead of rebinding the parameter.
    if output_truncation_bool is None:
        # Omit the second argument to exercise the wrapper's default.
        wrapped_env = StepAPICompatibility(env())
    else:
        wrapped_env = StepAPICompatibility(env(), output_truncation_bool)

    step_returns = wrapped_env.step(0)
    wrapped_env.close()

    # Check the arity explicitly, as the old-API tests do. A wrong-length tuple
    # then fails with a clear assertion rather than an unpacking ValueError.
    assert len(step_returns) == 5
    _, _, terminated, truncated, _ = step_returns
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)
|
|
|
|
|
|
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
def test_step_compatibility_to_old_api(env):
    """Passing ``False`` as the wrapper's second argument must yield the old 4-tuple step API."""
    wrapped_env = StepAPICompatibility(env(), False)
    result = wrapped_env.step(0)

    assert len(result) == 4
    # With the length pinned to 4, index 2 is the ``done`` flag.
    done = result[2]
    assert isinstance(done, bool)
|
|
|
|
|
|
@pytest.mark.parametrize("apply_step_compatibility", [None, True, False])
def test_step_compatibility_in_make(apply_step_compatibility):
    """``gym.make`` must honour ``apply_step_compatibility`` for an old-API env.

    ``True`` must produce the new 5-tuple step API. ``False`` and the default
    (argument omitted, the ``None`` case) must leave the old 4-tuple API.
    """
    env_id = "OldStepEnv-v0"
    # The registry is global state. Remove the entry afterwards so it does not
    # leak into other tests, and so each parametrized run registers a fresh id
    # instead of registering the same id again.
    gym.register(env_id, entry_point=OldStepEnv)
    try:
        make_kwargs = {"disable_env_checker": True}
        if apply_step_compatibility is not None:
            # Only forward the flag when it is set, so the None case exercises
            # make's default behaviour.
            make_kwargs["apply_step_compatibility"] = apply_step_compatibility
        env = gym.make(env_id, **make_kwargs)
        try:
            env.reset()
            step_returns = env.step(0)
        finally:
            env.close()
    finally:
        gym.envs.registration.registry.pop(env_id, None)

    if apply_step_compatibility:
        assert len(step_returns) == 5
        _, _, terminated, truncated, _ = step_returns
        assert isinstance(terminated, bool)
        assert isinstance(truncated, bool)
    else:
        assert len(step_returns) == 4
        _, _, done, _ = step_returns
        assert isinstance(done, bool)
|