mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 06:16:32 +00:00
223 lines
8.2 KiB
Python
223 lines
8.2 KiB
Python
"""Provides a generic testing environment for use in tests with custom reset, step and render functions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import types
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
import gymnasium as gym
|
|
from gymnasium import spaces
|
|
from gymnasium.core import ActType, ObsType
|
|
from gymnasium.envs.registration import EnvSpec
|
|
from gymnasium.vector import VectorEnv
|
|
from gymnasium.vector.utils import batch_space
|
|
from gymnasium.vector.vector_env import AutoresetMode
|
|
|
|
|
|
def basic_reset_func(
    self,
    *,
    seed: int | None = None,
    options: dict | None = None,
) -> tuple[ObsType, dict]:
    """Reset callable for :class:`GenericTestEnv` that returns a sampled observation.

    Seed handling is delegated to the base-class ``reset``. The observation space is then
    seeded with ``self.np_random_seed`` before a single observation is sampled from it.

    Args:
        seed: Optional seed forwarded unchanged to the base-class ``reset``.
        options: Not interpreted; echoed back inside the info dict.

    Returns:
        A pair ``(observation, {"options": options})``.
    """
    super(GenericTestEnv, self).reset(seed=seed)
    env_seed = self.np_random_seed
    self.observation_space.seed(env_seed)
    first_observation = self.observation_space.sample()
    info = {"options": options}
    return first_observation, info
|
|
|
|
|
|
def old_reset_func(self) -> ObsType:
    """Legacy-style reset callable that returns only an observation (no info dict).

    Calls the base-class ``reset`` without a seed, then returns one sample drawn from the
    observation space.
    """
    super(GenericTestEnv, self).reset()
    observation = self.observation_space.sample()
    return observation
|
|
|
|
|
|
def basic_step_func(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]:
    """Step callable following the five-tuple step API.

    The ``action`` argument is ignored; the observation is a fresh sample from the
    observation space.

    Returns:
        ``(observation, 0, False, False, {})``: zero reward, not terminated, not
        truncated, empty info.
    """
    observation = self.observation_space.sample()
    reward, terminated, truncated = 0, False, False
    return observation, reward, terminated, truncated, {}
|
|
|
|
|
|
def old_step_func(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
    """Step callable following the legacy four-tuple step API ``(obs, reward, done, info)``.

    The ``action`` argument is ignored; the observation is a fresh sample from the
    observation space.
    """
    observation = self.observation_space.sample()
    reward, done = 0, False
    return observation, reward, done, {}
|
|
|
|
|
|
def basic_render_func(self):
    """Render callable that draws nothing and returns ``None``."""
    return None
|
|
|
|
|
|
class GenericTestEnv(gym.Env):
    """A generic testing environment for use in tests that need a modified environment.

    The ``reset``, ``step`` and ``render`` behaviour is supplied as plain functions that are
    bound to the instance in ``__init__``. Tests can therefore build environments with
    unusual behaviour (e.g. the old four-tuple step API) without writing a new subclass.
    """

    def __init__(
        self,
        action_space: spaces.Space = spaces.Box(0, 1, (1,)),
        observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
        reset_func: Callable = basic_reset_func,
        step_func: Callable = basic_step_func,
        render_func: Callable = basic_render_func,
        metadata: dict[str, Any] = {"render_modes": []},
        render_mode: str | None = None,
        spec: EnvSpec = EnvSpec(
            "TestingEnv-v0", "tests.testing_env:GenericTestEnv", max_episode_steps=100
        ),
    ):
        """Generic testing environment constructor.

        NOTE: the default ``action_space`` and ``observation_space`` are single ``Box``
        instances created when the class is defined. Every environment built with the
        defaults therefore shares those space objects.

        Args:
            action_space: The environment action space. If ``None``, ``action_space`` is
                not set on the instance.
            observation_space: The environment observation space. If ``None``,
                ``observation_space`` is not set on the instance.
            reset_func: Function bound to the instance as ``reset``. If ``None``, the
                placeholder ``reset`` (which raises ``NotImplementedError``) is kept.
            step_func: Function bound to the instance as ``step``. If ``None``, the
                placeholder ``step`` is kept.
            render_func: Function bound to the instance as ``render``. If ``None``, the
                placeholder ``render`` is kept.
            metadata: The environment metadata. A deep copy is stored, so mutating
                ``env.metadata`` never alters the caller's dict or the shared default.
            render_mode: The render mode of the environment
            spec: The environment spec
        """
        import copy

        # Default arguments are evaluated once, so without a copy every default-constructed
        # env would alias the same ``{"render_modes": []}`` dict and mutations would leak
        # between environments (and between tests).
        self.metadata = copy.deepcopy(metadata)
        self.render_mode = render_mode
        self.spec = spec

        if observation_space is not None:
            self.observation_space = observation_space
        if action_space is not None:
            self.action_space = action_space

        # Binding the functions as instance attributes shadows the placeholder methods below.
        if reset_func is not None:
            self.reset = types.MethodType(reset_func, self)
        if step_func is not None:
            self.step = types.MethodType(step_func, self)
        if render_func is not None:
            self.render = types.MethodType(render_func, self)

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict | None = None,
    ) -> ObsType | tuple[ObsType, dict]:
        """Placeholder reset, used only when ``reset_func`` was ``None``.

        Raises:
            NotImplementedError: Always.
        """
        # If you need a default working reset function, use `basic_reset_func` above
        raise NotImplementedError("TestingEnv reset_fn is not set.")

    def step(
        self, action: ActType
    ) -> (
        tuple[ObsType, float, bool, bool, dict[str, Any]]
        | tuple[ObsType, float, bool, dict[str, Any]]
    ):
        """Placeholder step, used only when ``step_func`` was ``None``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("TestingEnv step_fn is not set.")

    def render(self):
        """Placeholder render, used only when ``render_func`` was ``None``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("testingEnv render_fn is not set.")
|
|
|
|
|
|
def basic_vector_reset_func(
    self,
    *,
    seed: int | None = None,
    options: dict | None = None,
) -> tuple[ObsType, dict]:
    """Reset callable for :class:`GenericTestVectorEnv` that returns a sampled observation.

    Seed handling is delegated to the base-class ``reset``. The environment's observation
    space (the batched space, for ``GenericTestVectorEnv``) is then seeded with
    ``self.np_random_seed`` and sampled once.

    Args:
        seed: Optional seed forwarded unchanged to the base-class ``reset``.
        options: Not interpreted; echoed back inside the info dict.

    Returns:
        A pair ``(observation, {"options": options})``.
    """
    super(GenericTestVectorEnv, self).reset(seed=seed)
    vector_seed = self.np_random_seed
    self.observation_space.seed(vector_seed)
    batched_observation = self.observation_space.sample()
    info = {"options": options}
    return batched_observation, info
|
|
|
|
|
|
def basic_vector_step_func(
    self, action: ActType
) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]:
    """Vectorised step callable following the five-tuple step API.

    The ``action`` argument is ignored. One observation is sampled from the observation
    space, and every sub-environment gets zero reward and is neither terminated nor
    truncated.

    Returns:
        ``(observation, rewards, terminations, truncations, {})``. ``rewards`` is a
        float64 zero array of length ``self.num_envs``. The two flag arrays are separate
        all-False bool arrays of the same length.
    """
    batch_size = self.num_envs
    observation = self.observation_space.sample()

    rewards = np.zeros(batch_size, dtype=np.float64)
    terminations = np.zeros(batch_size, dtype=np.bool_)
    truncations = np.zeros(batch_size, dtype=np.bool_)

    return observation, rewards, terminations, truncations, {}
|
|
|
|
|
|
def basic_vector_render_func(self):
    """Vector render callable that draws nothing and returns ``None``."""
    return None
|
|
|
|
|
|
class GenericTestVectorEnv(VectorEnv):
    """A generic testing vector environment similar to GenericTestEnv.

    Some tests cannot use SyncVectorEnv, e.g. when returning non-numpy arrays in the observations.
    In these cases, GenericTestVectorEnv can be used to simulate a vector environment.
    """

    def __init__(
        self,
        num_envs: int = 1,
        action_space: spaces.Space = spaces.Box(0, 1, (1,)),
        observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
        reset_func: Callable = basic_vector_reset_func,
        step_func: Callable = basic_vector_step_func,
        render_func: Callable = basic_vector_render_func,
        metadata: dict[str, Any] = {
            "render_modes": [],
            "autoreset_mode": AutoresetMode.NEXT_STEP,
        },
        render_mode: str | None = None,
        spec: EnvSpec = EnvSpec(
            "TestingVectorEnv-v0",
            "tests.testing_env:GenericTestVectorEnv",
            max_episode_steps=100,
        ),
    ):
        """Generic testing vector environment constructor.

        NOTE: the default single ``action_space`` and ``observation_space`` are ``Box``
        instances created when the class is defined. They are shared by every instance
        that uses the defaults, because they are stored as ``single_*_space``.

        Args:
            num_envs: The number of environments to create
            action_space: The single (per sub-environment) action space; it is batched
                with ``batch_space`` to form ``action_space``.
            observation_space: The single (per sub-environment) observation space; it is
                batched with ``batch_space`` to form ``observation_space``.
            reset_func: Function bound to the instance as ``reset``. If ``None``, the
                placeholder ``reset`` (which raises ``NotImplementedError``) is kept.
            step_func: Function bound to the instance as ``step``. If ``None``, the
                placeholder ``step`` is kept.
            render_func: Function bound to the instance as ``render``. If ``None``, the
                placeholder ``render`` is kept.
            metadata: The environment metadata. A deep copy is stored, so mutating
                ``env.metadata`` never alters the caller's dict or the shared default.
            render_mode: The render mode of the environment
            spec: The environment spec
        """
        import copy

        super().__init__()

        self.num_envs = num_envs
        # Default arguments are evaluated once, so without a copy every default-constructed
        # env would alias the same metadata dict and mutations would leak between tests.
        # (Enum members such as ``AutoresetMode.NEXT_STEP`` are preserved by identity.)
        self.metadata = copy.deepcopy(metadata)
        self.render_mode = render_mode
        self.spec = spec

        # Set the single spaces and create batched spaces
        self.single_observation_space = observation_space
        self.single_action_space = action_space
        self.observation_space = batch_space(observation_space, num_envs)
        self.action_space = batch_space(action_space, num_envs)

        # Bind the functions to the instance; these shadow the placeholder methods below.
        if reset_func is not None:
            self.reset = types.MethodType(reset_func, self)
        if step_func is not None:
            self.step = types.MethodType(step_func, self)
        if render_func is not None:
            self.render = types.MethodType(render_func, self)

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict | None = None,
    ) -> tuple[ObsType, dict]:
        """Placeholder reset, used only when ``reset_func`` was ``None``.

        Raises:
            NotImplementedError: Always.
        """
        # If you need a default working reset function, use `basic_vector_reset_func` above
        raise NotImplementedError("TestingVectorEnv reset_fn is not set.")

    def step(
        self, action: ActType
    ) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]:
        """Placeholder step, used only when ``step_func`` was ``None``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("TestingVectorEnv step_fn is not set.")

    def render(self) -> tuple[Any, ...] | None:
        """Placeholder render, used only when ``render_func`` was ``None``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("TestingVectorEnv render_fn is not set.")
|