mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 17:45:07 +00:00
Adding return_info argument to reset to allow for optional info dict as a second return value (#2546)
* initial draft of optional info dict in reset function, implemented for cartpole, tests seem to be passing * merged core.py * updated return type annotation for reset function in core.py * optional metadata with return_info from reset added for all first party environments, with corresponding tests. Incomplete implementation for wrappers and vector wrappers * removed Optional type for return_info arguments * added tests for return_info to normalize wrapper and sync_vector_env * autoformatted using black * added optional reset metadata tests to several wrappers * added return_info capability to async_vector_env.py and test to verify functionality * added optional return_info test for record_video.py * removed tests for mujoco environments * autoformatted * improved test coverage for optional reset return_info * re-removed unit test envs accidentally reintroduced in merge * removed unnecessary import * changes based on code-review * small fix to core wrapper typing and autoformatted record_epsisode_stats * small change to pass flake8 style
This commit is contained in:
14
gym/core.py
14
gym/core.py
@@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import TypeVar, Generic, Tuple, SupportsFloat
|
||||
from typing import Optional
|
||||
from typing import TypeVar, Generic, Tuple, Union, Optional, SupportsFloat
|
||||
|
||||
import gym
|
||||
from gym import error, spaces
|
||||
@@ -71,8 +70,12 @@ class Env(Generic[ObsType, ActType]):
|
||||
|
||||
@abstractmethod
|
||||
def reset(
|
||||
self, *, seed: Optional[int] = None, options: Optional[dict] = None
|
||||
) -> ObsType:
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
) -> Union[ObsType, tuple[ObsType, dict]]:
|
||||
"""Resets the environment to an initial state and returns an initial
|
||||
observation.
|
||||
|
||||
@@ -84,6 +87,7 @@ class Env(Generic[ObsType, ActType]):
|
||||
|
||||
Returns:
|
||||
observation (object): the initial observation.
|
||||
info (optional dictionary): a dictionary containing extra information, this is only returned if return_info is set to true
|
||||
"""
|
||||
# Initialize the RNG if it's the first reset, or if the seed is manually passed
|
||||
if seed is not None or self.np_random is None:
|
||||
@@ -262,7 +266,7 @@ class Wrapper(Env[ObsType, ActType]):
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||
return self.env.step(action)
|
||||
|
||||
def reset(self, **kwargs) -> ObsType:
|
||||
def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]:
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
def render(self, mode="human", **kwargs):
|
||||
|
@@ -350,7 +350,13 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
x2 = max(p[0] for p in poly)
|
||||
self.cloud_poly.append((poly, x1, x2))
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self._destroy()
|
||||
self.world.contactListener_bug_workaround = ContactDetector(self)
|
||||
@@ -436,8 +442,10 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
return fraction
|
||||
|
||||
self.lidar = [LidarCallback() for _ in range(10)]
|
||||
|
||||
return self.step(np.array([0, 0, 0, 0]))[0]
|
||||
if not return_info:
|
||||
return self.step(np.array([0, 0, 0, 0]))[0]
|
||||
else:
|
||||
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
||||
|
||||
def step(self, action):
|
||||
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help
|
||||
|
@@ -374,7 +374,13 @@ class CarRacing(gym.Env, EzPickle):
|
||||
self.track = track
|
||||
return True
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self._destroy()
|
||||
self.reward = 0.0
|
||||
@@ -395,7 +401,10 @@ class CarRacing(gym.Env, EzPickle):
|
||||
)
|
||||
self.car = Car(self.world, *self.track[0][1:4])
|
||||
|
||||
return self.step(None)[0]
|
||||
if not return_info:
|
||||
return self.step(None)[0]
|
||||
else:
|
||||
return self.step(None)[0], {}
|
||||
|
||||
def step(self, action):
|
||||
if action is not None:
|
||||
|
@@ -183,7 +183,13 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self.world.DestroyBody(self.legs[0])
|
||||
self.world.DestroyBody(self.legs[1])
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self._destroy()
|
||||
self.world.contactListener_keepref = ContactDetector(self)
|
||||
@@ -288,7 +294,10 @@ class LunarLander(gym.Env, EzPickle):
|
||||
|
||||
self.drawlist = [self.lander] + self.legs
|
||||
|
||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
|
||||
if not return_info:
|
||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
|
||||
else:
|
||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
|
||||
|
||||
def _create_particle(self, mass, x, y, ttl):
|
||||
p = self.world.CreateDynamicBody(
|
||||
|
@@ -162,12 +162,21 @@ class AcrobotEnv(core.Env):
|
||||
self.action_space = spaces.Discrete(3)
|
||||
self.state = None
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
|
||||
np.float32
|
||||
)
|
||||
return self._get_ob()
|
||||
if not return_info:
|
||||
return self._get_ob()
|
||||
else:
|
||||
return self._get_ob(), {}
|
||||
|
||||
def step(self, a):
|
||||
s = self.state
|
||||
|
@@ -165,11 +165,20 @@ class CartPoleEnv(gym.Env):
|
||||
|
||||
return np.array(self.state, dtype=np.float32), reward, done, {}
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
|
||||
self.steps_beyond_done = None
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
if not return_info:
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
else:
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
|
||||
def render(self, mode="human"):
|
||||
screen_width = 600
|
||||
|
@@ -130,10 +130,19 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
self.state = np.array([position, velocity], dtype=np.float32)
|
||||
return self.state, reward, done, {}
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
if not return_info:
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
else:
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
|
||||
def _height(self, xs):
|
||||
return np.sin(3 * xs) * 0.45 + 0.55
|
||||
|
@@ -93,10 +93,19 @@ class MountainCarEnv(gym.Env):
|
||||
self.state = (position, velocity)
|
||||
return np.array(self.state, dtype=np.float32), reward, done, {}
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
if not return_info:
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
else:
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
|
||||
def _height(self, xs):
|
||||
return np.sin(3 * xs) * 0.45 + 0.55
|
||||
|
@@ -109,12 +109,21 @@ class PendulumEnv(gym.Env):
|
||||
self.state = np.array([newth, newthdot])
|
||||
return self._get_obs(), -costs, False, {}
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
high = np.array([np.pi, 1])
|
||||
self.state = self.np_random.uniform(low=-high, high=high)
|
||||
self.last_u = None
|
||||
return self._get_obs()
|
||||
if not return_info:
|
||||
return self._get_obs()
|
||||
else:
|
||||
return self._get_obs(), {}
|
||||
|
||||
def _get_obs(self):
|
||||
theta, thetadot = self.state
|
||||
|
@@ -103,11 +103,20 @@ class MujocoEnv(gym.Env):
|
||||
|
||||
# -----------------------------
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.sim.reset()
|
||||
ob = self.reset_model()
|
||||
return ob
|
||||
if not return_info:
|
||||
return ob
|
||||
else:
|
||||
return ob, {}
|
||||
|
||||
def set_state(self, qpos, qvel):
|
||||
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
|
||||
|
@@ -153,11 +153,19 @@ class BlackjackEnv(gym.Env):
|
||||
def _get_obs(self):
|
||||
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.dealer = draw_hand(self.np_random)
|
||||
self.player = draw_hand(self.np_random)
|
||||
return self._get_obs()
|
||||
if not return_info:
|
||||
return self._get_obs()
|
||||
else:
|
||||
return self._get_obs(), {}
|
||||
|
||||
def render(self, mode="human"):
|
||||
player_sum, dealer_card_value, usable_ace = self._get_obs()
|
||||
|
@@ -102,11 +102,20 @@ class CliffWalkingEnv(Env):
|
||||
self.lastaction = a
|
||||
return (int(s), r, d, {"prob": p})
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||
self.lastaction = None
|
||||
return int(self.s)
|
||||
if not return_info:
|
||||
return int(self.s)
|
||||
else:
|
||||
return int(self.s), {"prob": 1}
|
||||
|
||||
def render(self, mode="human"):
|
||||
outfile = StringIO() if mode == "ansi" else sys.stdout
|
||||
|
@@ -212,11 +212,21 @@ class FrozenLakeEnv(Env):
|
||||
self.lastaction = a
|
||||
return (int(s), r, d, {"prob": p})
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||
self.lastaction = None
|
||||
return int(self.s)
|
||||
|
||||
if not return_info:
|
||||
return int(self.s)
|
||||
else:
|
||||
return int(self.s), {"prob": 1}
|
||||
|
||||
def render(self, mode="human"):
|
||||
desc = self.desc.tolist()
|
||||
|
@@ -212,11 +212,20 @@ class TaxiEnv(Env):
|
||||
self.lastaction = a
|
||||
return (int(s), r, d, {"prob": p})
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||
self.lastaction = None
|
||||
return int(self.s)
|
||||
if not return_info:
|
||||
return int(self.s)
|
||||
else:
|
||||
return int(self.s), {"prob": 1}
|
||||
|
||||
def render(self, mode="human"):
|
||||
outfile = StringIO() if mode == "ansi" else sys.stdout
|
||||
|
@@ -215,6 +215,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
def reset_async(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Send the calls to :obj:`reset` to each sub-environment.
|
||||
@@ -248,6 +249,8 @@ class AsyncVectorEnv(VectorEnv):
|
||||
single_kwargs = {}
|
||||
if single_seed is not None:
|
||||
single_kwargs["seed"] = single_seed
|
||||
if return_info:
|
||||
single_kwargs["return_info"] = return_info
|
||||
if options is not None:
|
||||
single_kwargs["options"] = options
|
||||
|
||||
@@ -255,7 +258,11 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._state = AsyncState.WAITING_RESET
|
||||
|
||||
def reset_wait(
|
||||
self, timeout=None, seed: Optional[int] = None, options: Optional[dict] = None
|
||||
self,
|
||||
timeout=None,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
@@ -270,6 +277,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
-------
|
||||
element of :attr:`~VectorEnv.observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
infos : list of dicts containing metadata
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -300,12 +308,25 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._raise_if_errors(successes)
|
||||
self._state = AsyncState.DEFAULT
|
||||
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, results, self.observations
|
||||
)
|
||||
if return_info:
|
||||
results, infos = zip(*results)
|
||||
infos = list(infos)
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, results, self.observations
|
||||
)
|
||||
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations
|
||||
), infos
|
||||
else:
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, results, self.observations
|
||||
)
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
|
||||
def step_async(self, actions):
|
||||
"""Send the calls to :obj:`step` to each sub-environment.
|
||||
@@ -618,8 +639,13 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
if command == "reset":
|
||||
observation = env.reset(**data)
|
||||
pipe.send((observation, True))
|
||||
if "return_info" in data and data["return_info"] == True:
|
||||
observation, info = env.reset(**data)
|
||||
pipe.send(((observation, info), True))
|
||||
else:
|
||||
observation = env.reset(**data)
|
||||
pipe.send((observation, True))
|
||||
|
||||
elif command == "step":
|
||||
observation, reward, done, info = env.step(data)
|
||||
if done:
|
||||
@@ -677,11 +703,18 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
if command == "reset":
|
||||
observation = env.reset(**data)
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
pipe.send((None, True))
|
||||
if "return_info" in data and data["return_info"] == True:
|
||||
observation, info = env.reset(**data)
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
pipe.send(((None, info), True))
|
||||
else:
|
||||
observation = env.reset(**data)
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
pipe.send((None, True))
|
||||
elif command == "step":
|
||||
observation, reward, done, info = env.step(data)
|
||||
if done:
|
||||
|
@@ -89,6 +89,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
def reset_wait(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
if seed is None:
|
||||
@@ -99,19 +100,34 @@ class SyncVectorEnv(VectorEnv):
|
||||
|
||||
self._dones[:] = False
|
||||
observations = []
|
||||
data_list = []
|
||||
for env, single_seed in zip(self.envs, seed):
|
||||
single_kwargs = {}
|
||||
|
||||
kwargs = {}
|
||||
if single_seed is not None:
|
||||
single_kwargs["seed"] = single_seed
|
||||
kwargs["seed"] = single_seed
|
||||
if options is not None:
|
||||
single_kwargs["options"] = options
|
||||
observation = env.reset(**single_kwargs)
|
||||
observations.append(observation)
|
||||
kwargs["options"] = options
|
||||
if return_info == True:
|
||||
kwargs["return_info"] = return_info
|
||||
|
||||
if not return_info:
|
||||
observation = env.reset(**kwargs)
|
||||
observations.append(observation)
|
||||
else:
|
||||
observation, data = env.reset(**kwargs)
|
||||
observations.append(observation)
|
||||
data_list.append(data)
|
||||
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, observations, self.observations
|
||||
)
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
if not return_info:
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
else:
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations
|
||||
), data_list
|
||||
|
||||
def step_async(self, actions):
|
||||
self._actions = iterate(self.action_space, actions)
|
||||
|
@@ -49,6 +49,7 @@ class VectorEnv(gym.Env):
|
||||
def reset_async(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
pass
|
||||
@@ -56,6 +57,7 @@ class VectorEnv(gym.Env):
|
||||
def reset_wait(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
@@ -64,6 +66,7 @@ class VectorEnv(gym.Env):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
r"""Reset all sub-environments and return a batch of initial observations.
|
||||
@@ -73,8 +76,8 @@ class VectorEnv(gym.Env):
|
||||
element of :attr:`observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
"""
|
||||
self.reset_async(seed=seed, options=options)
|
||||
return self.reset_wait(seed=seed, options=options)
|
||||
self.reset_async(seed=seed, return_info=return_info, options=options)
|
||||
return self.reset_wait(seed=seed, return_info=return_info, options=options)
|
||||
|
||||
def step_async(self, actions):
|
||||
pass
|
||||
|
@@ -63,13 +63,26 @@ class NormalizeObservation(gym.core.Wrapper):
|
||||
obs = self.normalize(np.array([obs]))[0]
|
||||
return obs, rews, dones, infos
|
||||
|
||||
def reset(self, **kwargs):
|
||||
obs = self.env.reset(**kwargs)
|
||||
def reset(
|
||||
self,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
obs = None
|
||||
info = None
|
||||
if not return_info:
|
||||
obs = self.env.reset(seed=seed, options=options)
|
||||
else:
|
||||
obs, info = self.env.reset(seed=seed, return_info=True, options=options)
|
||||
if self.is_vector_env:
|
||||
obs = self.normalize(obs)
|
||||
else:
|
||||
obs = self.normalize(np.array([obs]))[0]
|
||||
return obs
|
||||
if not return_info:
|
||||
return obs
|
||||
else:
|
||||
return obs, info
|
||||
|
||||
def normalize(self, obs):
|
||||
self.obs_rms.update(obs)
|
||||
|
@@ -55,6 +55,23 @@ def test_env(spec):
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", spec_list)
|
||||
def test_reset_info(spec):
|
||||
|
||||
with pytest.warns(None) as warnings:
|
||||
env = spec.make()
|
||||
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset()
|
||||
assert ob_space.contains(obs)
|
||||
obs = env.reset(return_info=False)
|
||||
assert ob_space.contains(obs)
|
||||
obs, info = env.reset(return_info=True)
|
||||
assert ob_space.contains(obs)
|
||||
assert isinstance(info, dict)
|
||||
env.close()
|
||||
|
||||
|
||||
# Run a longer rollout on some environments
|
||||
def test_random_rollout():
|
||||
for env in [envs.make("CartPole-v0"), envs.make("FrozenLake-v1")]:
|
||||
|
@@ -35,13 +35,22 @@ class UnknownSpacesEnv(core.Env):
|
||||
on external resources), it is not encouraged.
|
||||
"""
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
|
||||
)
|
||||
self.action_space = spaces.Discrete(3)
|
||||
return self.observation_space.sample() # Dummy observation
|
||||
if not return_info:
|
||||
return self.observation_space.sample() # Dummy observation
|
||||
else:
|
||||
return self.observation_space.sample(), {} # Dummy observation with info
|
||||
|
||||
def step(self, action):
|
||||
observation = self.observation_space.sample() # Dummy observation
|
||||
|
@@ -40,6 +40,32 @@ def test_reset_async_vector_env(shared_memory):
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset(return_info=False)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations, infos = env.reset(return_info=True)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
assert isinstance(infos, list)
|
||||
assert all([isinstance(info, dict) for info in infos])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||
|
@@ -31,6 +31,37 @@ def test_reset_sync_vector_env():
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
del observations
|
||||
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset(return_info=False)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
del observations
|
||||
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations, infos = env.reset(return_info=True)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
assert isinstance(infos, list)
|
||||
assert all([isinstance(info, dict) for info in infos])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||
def test_step_sync_vector_env(use_single_action_space):
|
||||
|
@@ -23,10 +23,19 @@ class DummyRewardEnv(gym.Env):
|
||||
self.t += 1
|
||||
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: Optional[bool] = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self.t = self.return_reward_idx
|
||||
return np.array([self.t])
|
||||
if not return_info:
|
||||
return np.array([self.t])
|
||||
else:
|
||||
return np.array([self.t]), {}
|
||||
|
||||
|
||||
def make_env(return_reward_idx):
|
||||
@@ -47,6 +56,20 @@ def test_normalize_observation():
|
||||
assert_almost_equal(env.obs_rms.mean, 1.0, decimal=4)
|
||||
|
||||
|
||||
def test_normalize_reset_info():
|
||||
env = DummyRewardEnv(return_reward_idx=0)
|
||||
env = NormalizeObservation(env)
|
||||
obs = env.reset()
|
||||
assert isinstance(obs, np.ndarray)
|
||||
del obs
|
||||
obs = env.reset(return_info=False)
|
||||
assert isinstance(obs, np.ndarray)
|
||||
del obs
|
||||
obs, info = env.reset(return_info=True)
|
||||
assert isinstance(obs, np.ndarray)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
|
||||
def test_normalize_return():
|
||||
env = DummyRewardEnv(return_reward_idx=0)
|
||||
env = NormalizeReward(env)
|
||||
|
21
tests/wrappers/test_order_enforcing.py
Normal file
21
tests/wrappers/test_order_enforcing.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.wrappers import OrderEnforcing
|
||||
|
||||
|
||||
def test_order_enforcing_reset_info():
|
||||
env = gym.make("CartPole-v1")
|
||||
env = OrderEnforcing(env)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset()
|
||||
assert ob_space.contains(obs)
|
||||
del obs
|
||||
obs = env.reset(return_info=False)
|
||||
assert ob_space.contains(obs)
|
||||
del obs
|
||||
obs, info = env.reset(return_info=True)
|
||||
assert ob_space.contains(obs)
|
||||
assert isinstance(info, dict)
|
@@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.wrappers import RecordEpisodeStatistics
|
||||
|
||||
@@ -24,6 +26,18 @@ def test_record_episode_statistics(env_id, deque_size):
|
||||
assert len(env.length_queue) == deque_size
|
||||
|
||||
|
||||
def test_record_episode_statistics_reset_info():
|
||||
env = gym.make("CartPole-v1")
|
||||
env = RecordEpisodeStatistics(env)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset()
|
||||
assert ob_space.contains(obs)
|
||||
del obs
|
||||
obs, info = env.reset(return_info=True)
|
||||
assert ob_space.contains(obs)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)]
|
||||
)
|
||||
|
@@ -1,7 +1,9 @@
|
||||
import pytest
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.wrappers import (
|
||||
RecordEpisodeStatistics,
|
||||
@@ -29,6 +31,36 @@ def test_record_video_using_default_trigger():
|
||||
shutil.rmtree("videos")
|
||||
|
||||
|
||||
def test_record_video_reset_return_info():
|
||||
env = gym.make("CartPole-v1")
|
||||
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||
ob_space = env.observation_space
|
||||
obs, info = env.reset(return_info=True)
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
shutil.rmtree("videos")
|
||||
assert ob_space.contains(obs)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
env = gym.make("CartPole-v1")
|
||||
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset(return_info=False)
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
shutil.rmtree("videos")
|
||||
assert ob_space.contains(obs)
|
||||
|
||||
env = gym.make("CartPole-v1")
|
||||
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset()
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
shutil.rmtree("videos")
|
||||
assert ob_space.contains(obs)
|
||||
|
||||
|
||||
def test_record_video_step_trigger():
|
||||
env = gym.make("CartPole-v1")
|
||||
env._max_episode_steps = 20
|
||||
|
21
tests/wrappers/test_time_limit.py
Normal file
21
tests/wrappers/test_time_limit.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.wrappers import TimeLimit
|
||||
|
||||
|
||||
def test_time_limit_reset_info():
|
||||
env = gym.make("CartPole-v1")
|
||||
env = TimeLimit(env)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset()
|
||||
assert ob_space.contains(obs)
|
||||
del obs
|
||||
obs = env.reset(return_info=False)
|
||||
assert ob_space.contains(obs)
|
||||
del obs
|
||||
obs, info = env.reset(return_info=True)
|
||||
assert ob_space.contains(obs)
|
||||
assert isinstance(info, dict)
|
Reference in New Issue
Block a user