mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 23:12:46 +00:00
Add testing for step api compatibility functions and wrapper (#3028)
* Initial commit * Fixed tests and forced TimeLimit.truncated to always exist when truncated or terminated * Fix CI issues * pre-commit * Revert back to old language * Revert changes to step api wrapper
This commit is contained in:
@@ -45,7 +45,7 @@ def data_equivalence(data_1, data_2) -> bool:
|
||||
return data_1.keys() == data_2.keys() and all(
|
||||
data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
|
||||
)
|
||||
elif isinstance(data_1, tuple):
|
||||
elif isinstance(data_1, (tuple, list)):
|
||||
return len(data_1) == len(data_2) and all(
|
||||
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
|
||||
)
|
||||
|
@@ -36,66 +36,41 @@ def step_to_new_api(
|
||||
assert len(step_returns) == 4
|
||||
observations, rewards, dones, infos = step_returns
|
||||
|
||||
terminateds = []
|
||||
truncateds = []
|
||||
if not is_vector_env:
|
||||
dones = [dones]
|
||||
|
||||
for i in range(len(dones)):
|
||||
# For every condition, handling - info single env / info vector env (list) / info vector env (dict)
|
||||
|
||||
# TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done`
|
||||
if (not is_vector_env and "TimeLimit.truncated" not in infos) or (
|
||||
is_vector_env
|
||||
and (
|
||||
(
|
||||
isinstance(infos, list)
|
||||
and "TimeLimit.truncated" not in infos[i]
|
||||
) # vector env, list info api
|
||||
or (
|
||||
"TimeLimit.truncated" not in infos
|
||||
or (
|
||||
"TimeLimit.truncated" in infos
|
||||
and not infos["TimeLimit.truncated"][i]
|
||||
)
|
||||
)
|
||||
# vector env, dict info api, for env i, vector mask `_TimeLimit.truncated` is not considered, to be compatible with envpool
|
||||
# For env i, `TimeLimit.truncated` not being present is treated same as being present and set to False.
|
||||
# therefore, terminated=True, truncated=True simultaneously is not allowed while using compatibility functions
|
||||
# with vector info
|
||||
)
|
||||
):
|
||||
terminateds.append(dones[i])
|
||||
truncateds.append(False)
|
||||
|
||||
# This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not.
|
||||
elif (
|
||||
infos["TimeLimit.truncated"]
|
||||
if not is_vector_env
|
||||
else (
|
||||
infos["TimeLimit.truncated"][i]
|
||||
if isinstance(infos, dict)
|
||||
else infos[i]["TimeLimit.truncated"]
|
||||
)
|
||||
):
|
||||
assert dones[i]
|
||||
terminateds.append(False)
|
||||
truncateds.append(True)
|
||||
else:
|
||||
# This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated,
|
||||
# but it also exceeded maximum timesteps at the same step. However to be compatible with envpool, and to be backward compatible
|
||||
# truncated is set to False here.
|
||||
assert dones[i]
|
||||
terminateds.append(True)
|
||||
truncateds.append(False)
|
||||
|
||||
# Cases to handle - info single env / info vector env (list) / info vector env (dict)
|
||||
if is_vector_env is False:
|
||||
truncated = infos.pop("TimeLimit.truncated", False)
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0],
|
||||
np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0],
|
||||
dones and not truncated,
|
||||
dones and truncated,
|
||||
infos,
|
||||
)
|
||||
elif isinstance(infos, list):
|
||||
truncated = np.array(
|
||||
[info.pop("TimeLimit.truncated", False) for info in infos]
|
||||
)
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
np.logical_and(dones, np.logical_not(truncated)),
|
||||
np.logical_and(dones, truncated),
|
||||
infos,
|
||||
)
|
||||
elif isinstance(infos, dict):
|
||||
num_envs = len(dones)
|
||||
truncated = infos.pop("TimeLimit.truncated", np.zeros(num_envs, dtype=bool))
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
np.logical_and(dones, np.logical_not(truncated)),
|
||||
np.logical_and(dones, truncated),
|
||||
infos,
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}"
|
||||
)
|
||||
|
||||
|
||||
def step_to_old_api(
|
||||
@@ -111,44 +86,45 @@ def step_to_old_api(
|
||||
return step_returns
|
||||
else:
|
||||
assert len(step_returns) == 5
|
||||
observations, rewards, terminateds, truncateds, infos = step_returns
|
||||
dones = []
|
||||
if not is_vector_env:
|
||||
terminateds = [terminateds]
|
||||
truncateds = [truncateds]
|
||||
observations, rewards, terminated, truncated, infos = step_returns
|
||||
|
||||
n_envs = len(terminateds)
|
||||
|
||||
for i in range(n_envs):
|
||||
dones.append(terminateds[i] or truncateds[i])
|
||||
if truncateds[i]:
|
||||
if is_vector_env:
|
||||
# handle vector infos for dict and list
|
||||
if isinstance(infos, dict):
|
||||
if "TimeLimit.truncated" not in infos:
|
||||
# TODO: This should ideally not be done manually and should use vector_env's _add_info()
|
||||
infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
|
||||
infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
|
||||
|
||||
infos["TimeLimit.truncated"][i] = (
|
||||
not terminateds[i] or infos["TimeLimit.truncated"][i]
|
||||
# Cases to handle - info single env / info vector env (list) / info vector env (dict)
|
||||
if is_vector_env is False:
|
||||
if truncated or terminated:
|
||||
infos["TimeLimit.truncated"] = truncated and not terminated
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
terminated or truncated,
|
||||
infos,
|
||||
)
|
||||
infos["_TimeLimit.truncated"][i] = True
|
||||
else:
|
||||
# if vector info is a list
|
||||
infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[
|
||||
i
|
||||
].get("TimeLimit.truncated", False)
|
||||
else:
|
||||
infos["TimeLimit.truncated"] = not terminateds[i] or infos.get(
|
||||
"TimeLimit.truncated", False
|
||||
elif isinstance(infos, list):
|
||||
for info, env_truncated, env_terminated in zip(
|
||||
infos, truncated, terminated
|
||||
):
|
||||
if env_truncated or env_terminated:
|
||||
info["TimeLimit.truncated"] = env_truncated and not env_terminated
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
np.logical_or(terminated, truncated),
|
||||
infos,
|
||||
)
|
||||
elif isinstance(infos, dict):
|
||||
if np.logical_or(np.any(truncated), np.any(terminated)):
|
||||
infos["TimeLimit.truncated"] = np.logical_and(
|
||||
truncated, np.logical_not(terminated)
|
||||
)
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
np.array(dones, dtype=np.bool_) if is_vector_env else dones[0],
|
||||
np.logical_or(terminated, truncated),
|
||||
infos,
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}"
|
||||
)
|
||||
|
||||
|
||||
def step_api_compatibility(
|
||||
|
@@ -34,7 +34,7 @@ class TimeLimit(gym.Wrapper):
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
|
||||
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
|
||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env, new_step_api)
|
||||
@@ -63,6 +63,9 @@ class TimeLimit(gym.Wrapper):
|
||||
self._elapsed_steps += 1
|
||||
|
||||
if self._elapsed_steps >= self._max_episode_steps:
|
||||
if self.new_step_api is True or terminated is False:
|
||||
# As the old step api cannot encode both terminated and truncated, we favor terminated in the case of both.
|
||||
# Therefore, if new step api (i.e. not old step api) or when terminated is False to prevent the overriding
|
||||
truncated = True
|
||||
|
||||
return step_api_compatibility(
|
||||
|
166
tests/utils/test_step_api_compatibility.py
Normal file
166
tests/utils/test_step_api_compatibility.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.utils.env_checker import data_equivalence
|
||||
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_vector_env, done_returns, expected_terminated, expected_truncated",
|
||||
(
|
||||
# Test each of the permutations for single environments with and without the old info
|
||||
(False, (0, 0, False, {"Test-info": True}), False, False),
|
||||
(False, (0, 0, False, {"TimeLimit.truncated": False}), False, False),
|
||||
(False, (0, 0, True, {}), True, False),
|
||||
(False, (0, 0, True, {"TimeLimit.truncated": True}), False, True),
|
||||
(False, (0, 0, True, {"Test-info": True}), True, False),
|
||||
# Test vectorise versions with both list and dict infos testing each permutation for sub-environments
|
||||
(
|
||||
True,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
np.array([False, True, True]),
|
||||
[{}, {}, {"TimeLimit.truncated": True}],
|
||||
),
|
||||
np.array([False, True, False]),
|
||||
np.array([False, False, True]),
|
||||
),
|
||||
(
|
||||
True,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
np.array([False, True, True]),
|
||||
{"TimeLimit.truncated": np.array([False, False, True])},
|
||||
),
|
||||
np.array([False, True, False]),
|
||||
np.array([False, False, True]),
|
||||
),
|
||||
# empty truncated info
|
||||
(
|
||||
True,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
np.array([False, True]),
|
||||
{},
|
||||
),
|
||||
np.array([False, True]),
|
||||
np.array([False, False]),
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_to_done_step_api(
|
||||
is_vector_env, done_returns, expected_terminated, expected_truncated
|
||||
):
|
||||
_, _, terminated, truncated, info = step_to_new_api(
|
||||
done_returns, is_vector_env=is_vector_env
|
||||
)
|
||||
assert np.all(terminated == expected_terminated)
|
||||
assert np.all(truncated == expected_truncated)
|
||||
|
||||
if is_vector_env is False:
|
||||
assert "TimeLimit.truncated" not in info
|
||||
elif isinstance(info, list):
|
||||
assert all("TimeLimit.truncated" not in sub_info for sub_info in info)
|
||||
else: # isinstance(info, dict)
|
||||
assert "TimeLimit.truncated" not in info
|
||||
|
||||
roundtripped_returns = step_to_old_api(
|
||||
(0, 0, terminated, truncated, info), is_vector_env=is_vector_env
|
||||
)
|
||||
assert data_equivalence(done_returns, roundtripped_returns)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_vector_env, terminated_truncated_returns, expected_done, expected_truncated",
|
||||
(
|
||||
(False, (0, 0, False, False, {"Test-info": True}), False, False),
|
||||
(False, (0, 0, True, False, {}), True, False),
|
||||
(False, (0, 0, False, True, {}), True, True),
|
||||
# (False, (), True, True), # Not possible to encode in the old step api
|
||||
# Test vector dict info
|
||||
(
|
||||
True,
|
||||
(0, 0, np.array([False, True, False]), np.array([False, False, True]), {}),
|
||||
np.array([False, True, True]),
|
||||
np.array([False, False, True]),
|
||||
),
|
||||
# Test vector dict info with no truncation
|
||||
(
|
||||
True,
|
||||
(0, 0, np.array([False, True]), np.array([False, False]), {}),
|
||||
np.array([False, True]),
|
||||
np.array([False, False]),
|
||||
),
|
||||
# Test vector list info
|
||||
(
|
||||
True,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
np.array([False, True, False]),
|
||||
np.array([False, False, True]),
|
||||
[{"Test-Info": True}, {}, {}],
|
||||
),
|
||||
np.array([False, True, True]),
|
||||
np.array([False, False, True]),
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_to_terminated_truncated_step_api(
|
||||
is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
|
||||
):
|
||||
_, _, done, info = step_to_old_api(
|
||||
terminated_truncated_returns, is_vector_env=is_vector_env
|
||||
)
|
||||
assert np.all(done == expected_done)
|
||||
|
||||
if is_vector_env is False:
|
||||
if expected_done:
|
||||
assert info["TimeLimit.truncated"] == expected_truncated
|
||||
else:
|
||||
assert "TimeLimit.truncated" not in info
|
||||
elif isinstance(info, list):
|
||||
for sub_info, env_done, env_truncated in zip(
|
||||
info, expected_done, expected_truncated
|
||||
):
|
||||
if env_done:
|
||||
assert sub_info["TimeLimit.truncated"] == env_truncated
|
||||
else:
|
||||
assert "TimeLimit.truncated" not in sub_info
|
||||
else: # isinstance(info, dict)
|
||||
if np.any(expected_done):
|
||||
assert np.all(info["TimeLimit.truncated"] == expected_truncated)
|
||||
else:
|
||||
assert "TimeLimit.truncated" not in info
|
||||
|
||||
roundtripped_returns = step_to_new_api(
|
||||
(0, 0, done, info), is_vector_env=is_vector_env
|
||||
)
|
||||
assert data_equivalence(terminated_truncated_returns, roundtripped_returns)
|
||||
|
||||
|
||||
def test_edge_case():
|
||||
# When converting between the two-step APIs this is not possible in a single case
|
||||
# terminated=True and truncated=True -> done=True and info={}
|
||||
# We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail
|
||||
_, _, done, info = step_to_old_api((0, 0, True, True, {}))
|
||||
assert done is True
|
||||
assert info == {"TimeLimit.truncated": False}
|
||||
|
||||
# Test with vector dict info
|
||||
_, _, done, info = step_to_old_api(
|
||||
(0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True
|
||||
)
|
||||
assert np.all(done)
|
||||
assert info == {"TimeLimit.truncated": np.array([False])}
|
||||
|
||||
# Test with vector list info
|
||||
_, _, done, info = step_to_old_api(
|
||||
(0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]),
|
||||
is_vector_env=True,
|
||||
)
|
||||
assert np.all(done)
|
||||
assert info == [{"Test-Info": True, "TimeLimit.truncated": False}]
|
@@ -1,91 +0,0 @@
|
||||
import pytest
|
||||
|
||||
import gym
|
||||
from gym.spaces import Discrete
|
||||
from gym.vector import AsyncVectorEnv, SyncVectorEnv
|
||||
from gym.wrappers import TimeLimit
|
||||
|
||||
|
||||
# An environment where termination happens after 20 steps
|
||||
class DummyEnv(gym.Env):
|
||||
def __init__(self):
|
||||
self.action_space = Discrete(2)
|
||||
self.observation_space = Discrete(2)
|
||||
self.terminal_timestep = 20
|
||||
|
||||
self.timestep = 0
|
||||
|
||||
def step(self, action):
|
||||
self.timestep += 1
|
||||
terminated = True if self.timestep >= self.terminal_timestep else False
|
||||
truncated = False
|
||||
|
||||
return 0, 0, terminated, truncated, {}
|
||||
|
||||
def reset(self):
|
||||
self.timestep = 0
|
||||
return 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("time_limit", [10, 20, 30])
|
||||
def test_terminated_truncated(time_limit):
|
||||
test_env = TimeLimit(DummyEnv(), time_limit, new_step_api=True)
|
||||
|
||||
terminated = False
|
||||
truncated = False
|
||||
test_env.reset()
|
||||
while not (terminated or truncated):
|
||||
_, _, terminated, truncated, _ = test_env.step(0)
|
||||
|
||||
if test_env.terminal_timestep < time_limit:
|
||||
assert terminated
|
||||
assert not truncated
|
||||
elif test_env.terminal_timestep == time_limit:
|
||||
assert (
|
||||
terminated
|
||||
), "`terminated` should be True even when termination and truncation happen at the same step"
|
||||
assert (
|
||||
truncated
|
||||
), "`truncated` should be True even when termination and truncation occur at same step "
|
||||
else:
|
||||
assert not terminated
|
||||
assert truncated
|
||||
|
||||
|
||||
def test_terminated_truncated_vector():
|
||||
env0 = TimeLimit(DummyEnv(), 10, new_step_api=True)
|
||||
env1 = TimeLimit(DummyEnv(), 20, new_step_api=True)
|
||||
env2 = TimeLimit(DummyEnv(), 30, new_step_api=True)
|
||||
|
||||
async_env = AsyncVectorEnv(
|
||||
[lambda: env0, lambda: env1, lambda: env2], new_step_api=True
|
||||
)
|
||||
async_env.reset()
|
||||
terminateds = [False, False, False]
|
||||
truncateds = [False, False, False]
|
||||
counter = 0
|
||||
while not all([x or y for x, y in zip(terminateds, truncateds)]):
|
||||
counter += 1
|
||||
_, _, terminateds, truncateds, _ = async_env.step(
|
||||
async_env.action_space.sample()
|
||||
)
|
||||
print(counter)
|
||||
assert counter == 20
|
||||
assert all(terminateds == [False, True, True])
|
||||
assert all(truncateds == [True, True, False])
|
||||
|
||||
sync_env = SyncVectorEnv(
|
||||
[lambda: env0, lambda: env1, lambda: env2], new_step_api=True
|
||||
)
|
||||
sync_env.reset()
|
||||
terminateds = [False, False, False]
|
||||
truncateds = [False, False, False]
|
||||
counter = 0
|
||||
while not all([x or y for x, y in zip(terminateds, truncateds)]):
|
||||
counter += 1
|
||||
_, _, terminateds, truncateds, _ = sync_env.step(
|
||||
async_env.action_space.sample()
|
||||
)
|
||||
assert counter == 20
|
||||
assert all(terminateds == [False, True, True])
|
||||
assert all(truncateds == [True, True, False])
|
@@ -138,6 +138,7 @@ def test_autoreset_wrapper_autoreset():
|
||||
"count": 0,
|
||||
"final_observation": np.array([3]),
|
||||
"final_info": {"count": 3},
|
||||
"TimeLimit.truncated": False,
|
||||
}
|
||||
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
Reference in New Issue
Block a user