Mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-24 07:22:43 +00:00
92 lines
2.8 KiB
Python
import pytest
|
|
|
|
import gym
|
|
from gym.spaces import Discrete
|
|
from gym.vector import AsyncVectorEnv, SyncVectorEnv
|
|
from gym.wrappers import TimeLimit
|
|
|
|
|
|
# An environment whose episodes terminate on the 20th step after a reset.
class DummyEnv(gym.Env):
    """Minimal deterministic environment for testing ``terminated``/``truncated``.

    Observations and rewards are always ``0`` and actions are ignored.
    ``step`` reports ``terminated=True`` once ``terminal_timestep`` steps have
    been taken since the last ``reset`` and always reports ``truncated=False``;
    truncation is left to wrappers such as ``TimeLimit``.
    """

    def __init__(self):
        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)
        # Number of steps after which the episode terminates.
        self.terminal_timestep = 20
        # Steps taken since the last reset.
        self.timestep = 0

    def step(self, action):
        """Advance one step (``action`` is ignored).

        Returns the five-tuple ``(observation, reward, terminated, truncated,
        info)`` used with ``new_step_api=True``.
        """
        self.timestep += 1
        terminated = self.timestep >= self.terminal_timestep
        truncated = False  # this env never truncates on its own
        return 0, 0, terminated, truncated, {}

    def reset(self, *, seed=None, return_info=False, options=None):
        """Restart the episode and return the initial observation ``0``.

        ``seed`` and ``options`` are accepted so wrappers and vector envs can
        forward them. ``seed`` is handed to the base ``gym.Env.reset``; the
        dynamics here are deterministic, so ``options`` is unused. When
        ``return_info`` is true, ``(0, {})`` is returned instead of ``0``.
        """
        super().reset(seed=seed)
        self.timestep = 0
        return (0, {}) if return_info else 0
|
|
|
|
|
|
@pytest.mark.parametrize("time_limit", [10, 20, 30])
def test_terminated_truncated(time_limit):
    """Check how ``TimeLimit`` combines truncation with the env's own termination.

    ``DummyEnv`` terminates on step ``terminal_timestep`` (20). A limit below
    that must only truncate, an equal limit must report both flags on the
    same step, and a larger limit must only terminate.
    """
    base_env = DummyEnv()
    test_env = TimeLimit(base_env, time_limit, new_step_api=True)

    terminated = truncated = False
    steps = 0
    test_env.reset()
    while not (terminated or truncated):
        _, _, terminated, truncated, _ = test_env.step(0)
        steps += 1

    # The episode must end at whichever limit is reached first.
    assert steps == min(base_env.terminal_timestep, time_limit)

    if base_env.terminal_timestep < time_limit:
        assert terminated
        assert not truncated
    elif base_env.terminal_timestep == time_limit:
        assert terminated, (
            "`terminated` should be True even when termination and truncation "
            "happen at the same step"
        )
        assert truncated, (
            "`truncated` should be True even when termination and truncation "
            "happen at the same step"
        )
    else:
        assert not terminated
        assert truncated
|
|
|
|
|
|
def _make_env_fns(time_limits=(10, 20, 30)):
    """Return one factory per time limit, each building a fresh ``TimeLimit(DummyEnv())``.

    Every call constructs new env objects, so different vector envs never
    share sub-env instances. ``limit=limit`` binds each value eagerly to
    avoid the late-binding closure pitfall.
    """
    return [
        lambda limit=limit: TimeLimit(DummyEnv(), limit, new_step_api=True)
        for limit in time_limits
    ]


def _run_until_all_done(vector_env):
    """Reset ``vector_env``, then step it with random actions until every sub-env is done on the same step.

    Returns ``(steps, terminateds, truncateds)``, where the flags are those
    reported by the final step.
    """
    vector_env.reset()
    terminateds = [False] * vector_env.num_envs
    truncateds = [False] * vector_env.num_envs
    steps = 0
    while not all(term or trunc for term, trunc in zip(terminateds, truncateds)):
        steps += 1
        _, _, terminateds, truncateds, _ = vector_env.step(
            vector_env.action_space.sample()
        )
    return steps, terminateds, truncateds


def test_terminated_truncated_vector():
    """Check the ``terminated``/``truncated`` flags reported by vector envs.

    The sub-envs have time limits 10, 20 and 30, and all of them terminate on
    step 20. Sub-env 0 is truncated at step 10; the vector env is expected to
    reset it automatically, so it is truncated again at step 20. Step 20 is
    therefore the first step on which every sub-env is done.
    """
    for vector_env_cls in (AsyncVectorEnv, SyncVectorEnv):
        vector_env = vector_env_cls(_make_env_fns(), new_step_api=True)
        try:
            steps, terminateds, truncateds = _run_until_all_done(vector_env)
        finally:
            # Shut down AsyncVectorEnv worker processes even if stepping fails.
            vector_env.close()

        name = vector_env_cls.__name__
        assert steps == 20, name
        assert list(terminateds) == [False, True, True], name
        assert list(truncateds) == [True, True, False], name
|