Files
Gymnasium/tests/test_core.py

285 lines
8.4 KiB
Python
Raw Normal View History

"""Checks that the core Gymnasium API is implemented as expected."""
import re
from typing import Any, Dict, Optional, SupportsFloat, Tuple
Seeding update (#2422) * Ditch most of the seeding.py and replace np_random with the numpy default_rng. Let's see if tests pass * Updated a bunch of RNG calls from the RandomState API to Generator API * black; didn't expect that, did ya? * Undo a typo * blaaack * More typo fixes * Fixed setting/getting state in multidiscrete spaces * Fix typo, fix a test to work with the new sampling * Correctly (?) pass the randomly generated seed if np_random is called with None as seed * Convert the Discrete sample to a python int (as opposed to np.int64) * Remove some redundant imports * First version of the compatibility layer for old-style RNG. Mainly to trigger tests. * Removed redundant f-strings * Style fixes, removing unused imports * Try to make tests pass by removing atari from the dockerfile * Try to make tests pass by removing atari from the setup * Try to make tests pass by removing atari from the setup * Try to make tests pass by removing atari from the setup * First attempt at deprecating `env.seed` and supporting `env.reset(seed=seed)` instead. Tests should hopefully pass but throw up a million warnings. * black; didn't expect that, didya? * Rename the reset parameter in VecEnvs back to `seed` * Updated tests to use the new seeding method * Removed a bunch of old `seed` calls. Fixed a bug in AsyncVectorEnv * Stop Discrete envs from doing part of the setup (and using the randomness) in init (as opposed to reset) * Add explicit seed to wrappers reset * Remove an accidental return * Re-add some legacy functions with a warning. * Use deprecation instead of regular warnings for the newly deprecated methods/functions
2021-12-08 22:14:15 +01:00
import numpy as np
import pytest
from gymnasium import Env, ObservationWrapper, RewardWrapper, Wrapper, spaces
from gymnasium.core import (
ActionWrapper,
ActType,
ObsType,
WrapperActType,
WrapperObsType,
)
from gymnasium.spaces import Box
from gymnasium.utils import seeding
2022-09-08 10:10:07 +01:00
from gymnasium.wrappers import OrderEnforcing, TimeLimit
from tests.testing_env import GenericTestEnv
# ==== Old testing code
2021-07-29 02:26:34 +02:00
class ArgumentEnv(Env):
    """Environment whose constructor takes one positional argument and counts its calls."""

    observation_space = spaces.Box(low=0, high=1, shape=(1,))
    action_space = spaces.Box(low=0, high=1, shape=(1,))

    # Class-level default; ``__init__`` shadows it with a per-instance value.
    calls = 0

    def __init__(self, arg):
        self.arg = arg
        # Reads the class default (0) and stores the incremented count on the instance.
        self.calls = self.calls + 1
2021-07-29 02:26:34 +02:00
class UnittestEnv(Env):
    """Minimal environment with statically declared spaces, used for wrapper tests.

    ``reset`` returns ``(observation, info)`` and ``step`` returns the Gymnasium
    five-element tuple ``(observation, reward, terminated, truncated, info)``.
    """

    observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
    action_space = spaces.Discrete(3)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Forward ``seed`` to ``Env.reset`` and return a sampled observation plus a dummy info dict."""
        super().reset(seed=seed)
        return self.observation_space.sample(), {"info": "dummy"}

    def step(self, action):
        """Ignore ``action``; return a sampled observation, zero reward and a non-final step."""
        observation = self.observation_space.sample()  # Dummy observation
        # terminated=False, truncated=False: same step API as ``reset``'s (obs, info) contract.
        return observation, 0.0, False, False, {}
class UnknownSpacesEnv(Env):
    """Environment that defines its observation & action spaces only on the first ``reset``.

    Although this pattern is sometimes necessary when implementing a new
    environment (e.g. if the spaces depend on external resources), it is not
    encouraged.
    """

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Forward ``seed`` to ``Env.reset``, (re)create both spaces and return ``(observation, {})``."""
        super().reset(seed=seed)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
        )
        self.action_space = spaces.Discrete(3)
        return self.observation_space.sample(), {}  # Dummy observation with info

    def step(self, action):
        """Ignore ``action``; return a sampled observation, zero reward and a non-final step."""
        observation = self.observation_space.sample()  # Dummy observation
        # Five-element Gymnasium step tuple: terminated=False, truncated=False.
        return observation, 0.0, False, False, {}
class OldStyleEnv(Env):
    """Environment using the legacy API.

    ``reset`` accepts no arguments and returns a bare observation, and ``step``
    returns the old 4-tuple. Ideally this style stays supported (for now).
    """

    def __init__(self):
        """Take no arguments and set up no state."""

    def reset(self):
        """Call ``Env.reset`` with its defaults and return the observation ``0`` alone."""
        super().reset()
        observation = 0
        return observation

    def step(self, action):
        """Ignore ``action`` and return the legacy ``(observation, reward, done, info)`` tuple."""
        observation, reward, done, info = 0, 0, False, {}
        return observation, reward, done, info
class NewPropertyWrapper(Wrapper):
    """Wrapper that overrides only the properties it is explicitly given.

    Any argument left as ``None`` is never assigned, so the matching property
    keeps forwarding to the wrapped environment.
    """

    def __init__(
        self,
        env,
        observation_space=None,
        action_space=None,
        reward_range=None,
        metadata=None,
    ):
        super().__init__(env)
        overrides = (
            ("observation_space", observation_space),
            ("action_space", action_space),
            ("reward_range", reward_range),
            ("metadata", metadata),
        )
        for attribute, value in overrides:
            # Skipping None leaves the property unset so forwarding can be tested.
            if value is not None:
                setattr(self, attribute, value)
def test_env_instantiation():
    """An env built with a positional argument stores it and runs ``__init__`` exactly once."""
    # Trivial on its face, but worth having given our usage of ``__new__``.
    instance = ArgumentEnv("arg")
    assert instance.arg == "arg"
    assert instance.calls == 1
# Keyword sets passed to ``NewPropertyWrapper``: each entry overrides a different
# subset of the wrapper's properties, leaving the rest to be forwarded.
properties = [
    # Observation space only.
    {
        "observation_space": spaces.Box(
            low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32
        )
    },
    # Action space only.
    {"action_space": spaces.Discrete(2)},
    # Reward range only.
    {"reward_range": (-1.0, 1.0)},
    # Metadata only.
    {"metadata": {"render_modes": ["human", "rgb_array_list"]}},
    # Observation and action spaces together.
    {
        "observation_space": spaces.Box(
            low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32
        ),
        "action_space": spaces.Discrete(2),
    },
]
@pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv])
@pytest.mark.parametrize("props", properties)
def test_wrapper_property_forwarding(class_, props):
    """Wrapper properties take the given overrides and otherwise forward to the base env."""
    wrapped = NewPropertyWrapper(class_(), **props)

    # UnknownSpacesEnv only defines its spaces once reset has been called.
    if isinstance(wrapped.unwrapped, UnknownSpacesEnv):
        wrapped.reset()

    # Properties explicitly set on the wrapper hold exactly the given values.
    for name in props:
        assert getattr(wrapped, name) == props[name]

    # Every remaining property is forwarded from the unwrapped environment.
    forwarded = {"observation_space", "action_space", "reward_range", "metadata"}
    for name in forwarded - set(props):
        assert getattr(wrapped, name) == getattr(wrapped.unwrapped, name)
def test_compatibility_with_old_style_env():
    """A legacy env whose ``reset`` takes no arguments still resets through standard wrappers."""
    wrapped = TimeLimit(OrderEnforcing(OldStyleEnv()))
    assert wrapped.reset() == 0
# ==== New testing code
class ExampleEnv(Env):
    """Minimal new-API environment with ``Box(0, 1)`` spaces and constant outputs."""

    def __init__(self):
        self.action_space = Box(0, 1)
        self.observation_space = Box(0, 1)

    def step(
        self, action: ActType
    ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
        """Ignore ``action`` and return a fixed, non-terminal, non-truncated transition."""
        observation, reward, terminated, truncated = 0, 0, False, False
        return observation, reward, terminated, truncated, {}

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[ObsType, dict]:
        """Ignore ``seed``/``options`` and return observation ``0`` with an empty info dict."""
        observation, info = 0, {}
        return observation, info
def test_gymnasium_env():
    """A freshly built env exposes the default ``Env`` attribute values."""
    example = ExampleEnv()

    assert example.metadata == {"render_modes": []}
    assert example.render_mode is None
    unbounded = (-float("inf"), float("inf"))
    assert example.reward_range == unbounded
    assert example.spec is None
    # The env has never been seeded or reset here, so no RNG should exist yet.
    assert example._np_random is None  # pyright: ignore [reportPrivateUsage]
class ExampleWrapper(Wrapper):
    """Wrapper that replaces every step reward with the constant ``new_reward``."""

    def __init__(self, env: Env[ObsType, ActType]):
        super().__init__(env)
        self.new_reward = 3

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Tuple[WrapperObsType, Dict[str, Any]]:
        """Delegate to ``Wrapper.reset`` with the same ``seed`` and ``options``."""
        return super().reset(seed=seed, options=options)

    def step(
        self, action: WrapperActType
    ) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]:
        """Step the wrapped env and substitute ``self.new_reward`` for its reward."""
        observation, _reward, terminated, truncated, info = self.env.step(action)
        return observation, self.new_reward, terminated, truncated, info

    def access_hidden_np_random(self):
        """Read the wrapper's own ``_np_random``.

        Expected to raise: wrappers must not hold their own ``_np_random`` and
        should use the unwrapped environment's instead.
        """
        return self._np_random
def test_gymnasium_wrapper():
    """Wrapper attributes forward to the wrapped env until overridden on the wrapper."""
    base = ExampleEnv()
    wrapped = ExampleWrapper(base)

    # metadata: forwarded by default, independent once set on the wrapper.
    assert base.metadata == wrapped.metadata
    wrapped.metadata = {"render_modes": ["rgb_array"]}
    assert base.metadata != wrapped.metadata

    assert base.render_mode == wrapped.render_mode

    # reward_range: same forward-then-override behaviour.
    assert base.reward_range == wrapped.reward_range
    wrapped.reward_range = (-1.0, 1.0)
    assert base.reward_range != wrapped.reward_range

    assert base.spec == wrapped.spec

    space_names = ("observation_space", "action_space")
    # Spaces reassigned on the base env are visible through the wrapper ...
    for name in space_names:
        setattr(base, name, Box(0, 1))
    for name in space_names:
        assert getattr(base, name) == getattr(wrapped, name)
    # ... until the wrapper assigns its own.
    for name in space_names:
        setattr(wrapped, name, Box(1, 2))
    for name in space_names:
        assert getattr(base, name) != getattr(wrapped, name)

    # An RNG assigned through the wrapper is the very object held by the base env.
    wrapped.np_random, _ = seeding.np_random()
    assert base._np_random is base.np_random  # pyright: ignore [reportPrivateUsage]
    assert base.np_random is wrapped.np_random
    assert 0 <= wrapped.np_random.uniform() <= 1

    expected_message = "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
    with pytest.raises(AttributeError, match=re.escape(expected_message)):
        print(wrapped.access_hidden_np_random())
class ExampleRewardWrapper(RewardWrapper):
    """Reward wrapper that discards the incoming reward and always reports ``1``."""

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Return the constant ``1`` regardless of ``reward``."""
        constant_reward = 1
        return constant_reward
class ExampleObservationWrapper(ObservationWrapper):
    """Observation wrapper that replaces every observation with ``np.array([1])``."""

    def observation(self, observation: ObsType) -> ObsType:
        """Return a fresh ``np.array([1])``, ignoring ``observation``."""
        replacement = np.array([1])
        return replacement
class ExampleActionWrapper(ActionWrapper):
    """Action wrapper that replaces every action with ``np.array([1])``."""

    def action(self, action: ActType) -> ActType:
        """Return a fresh ``np.array([1])``, ignoring ``action``."""
        replacement = np.array([1])
        return replacement
def test_wrapper_types():
    """Reward, observation and action wrappers each transform their respective values."""
    base = GenericTestEnv()

    # RewardWrapper: the step reward becomes 1.
    rewarded = ExampleRewardWrapper(base)
    rewarded.reset()
    _obs, reward, _terminated, _truncated, _info = rewarded.step(0)
    assert reward == 1

    # ObservationWrapper: both reset and step observations become [1].
    observed = ExampleObservationWrapper(base)
    reset_obs, _ = observed.reset()
    assert reset_obs == np.array([1])
    step_obs, _reward, _terminated, _truncated, _info = observed.step(0)
    assert step_obs == np.array([1])

    # ActionWrapper: this env echoes the action back as its observation, so the
    # observation reveals the action after the wrapper transformed it.
    echo_env = GenericTestEnv(step_func=lambda self, action: (action, 0, False, False, {}))
    acted = ExampleActionWrapper(echo_env)
    echoed, _reward, _terminated, _truncated, _info = acted.step(0)
    assert echoed == np.array([1])