mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 01:50:19 +00:00
Adding return_info argument to reset to allow for optional info dict as a second return value (#2546)
* initial draft of optional info dict in reset function, implemented for cartpole, tests seem to be passing
* merged core.py
* updated return type annotation for reset function in core.py
* optional metadata with return_info from reset added for all first party environments, with corresponding tests. Incomplete implementation for wrappers and vector wrappers
* removed Optional type for return_info arguments
* added tests for return_info to normalize wrapper and sync_vector_env
* autoformatted using black
* added optional reset metadata tests to several wrappers
* added return_info capability to async_vector_env.py and test to verify functionality
* added optional return_info test for record_video.py
* removed tests for mujoco environments
* autoformatted
* improved test coverage for optional reset return_info
* re-removed unit test envs accidentally reintroduced in merge
* removed unnecessary import
* changes based on code-review
* small fix to core wrapper typing and autoformatted record_episode_stats
* small change to pass flake8 style
This commit is contained in:
14
gym/core.py
14
gym/core.py
@@ -1,8 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import TypeVar, Generic, Tuple, SupportsFloat
|
from typing import TypeVar, Generic, Tuple, Union, Optional, SupportsFloat
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym import error, spaces
|
from gym import error, spaces
|
||||||
@@ -71,8 +70,12 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(
|
def reset(
|
||||||
self, *, seed: Optional[int] = None, options: Optional[dict] = None
|
self,
|
||||||
) -> ObsType:
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
) -> Union[ObsType, tuple[ObsType, dict]]:
|
||||||
"""Resets the environment to an initial state and returns an initial
|
"""Resets the environment to an initial state and returns an initial
|
||||||
observation.
|
observation.
|
||||||
|
|
||||||
@@ -84,6 +87,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
observation (object): the initial observation.
|
observation (object): the initial observation.
|
||||||
|
info (optional dictionary): a dictionary containing extra information, this is only returned if return_info is set to true
|
||||||
"""
|
"""
|
||||||
# Initialize the RNG if it's the first reset, or if the seed is manually passed
|
# Initialize the RNG if it's the first reset, or if the seed is manually passed
|
||||||
if seed is not None or self.np_random is None:
|
if seed is not None or self.np_random is None:
|
||||||
@@ -262,7 +266,7 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||||
return self.env.step(action)
|
return self.env.step(action)
|
||||||
|
|
||||||
def reset(self, **kwargs) -> ObsType:
|
def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]:
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
def render(self, mode="human", **kwargs):
|
def render(self, mode="human", **kwargs):
|
||||||
|
@@ -350,7 +350,13 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
x2 = max(p[0] for p in poly)
|
x2 = max(p[0] for p in poly)
|
||||||
self.cloud_poly.append((poly, x1, x2))
|
self.cloud_poly.append((poly, x1, x2))
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self._destroy()
|
self._destroy()
|
||||||
self.world.contactListener_bug_workaround = ContactDetector(self)
|
self.world.contactListener_bug_workaround = ContactDetector(self)
|
||||||
@@ -436,8 +442,10 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
return fraction
|
return fraction
|
||||||
|
|
||||||
self.lidar = [LidarCallback() for _ in range(10)]
|
self.lidar = [LidarCallback() for _ in range(10)]
|
||||||
|
if not return_info:
|
||||||
return self.step(np.array([0, 0, 0, 0]))[0]
|
return self.step(np.array([0, 0, 0, 0]))[0]
|
||||||
|
else:
|
||||||
|
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help
|
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help
|
||||||
|
@@ -374,7 +374,13 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
self.track = track
|
self.track = track
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self._destroy()
|
self._destroy()
|
||||||
self.reward = 0.0
|
self.reward = 0.0
|
||||||
@@ -395,7 +401,10 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
)
|
)
|
||||||
self.car = Car(self.world, *self.track[0][1:4])
|
self.car = Car(self.world, *self.track[0][1:4])
|
||||||
|
|
||||||
|
if not return_info:
|
||||||
return self.step(None)[0]
|
return self.step(None)[0]
|
||||||
|
else:
|
||||||
|
return self.step(None)[0], {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
if action is not None:
|
if action is not None:
|
||||||
|
@@ -183,7 +183,13 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
self.world.DestroyBody(self.legs[0])
|
self.world.DestroyBody(self.legs[0])
|
||||||
self.world.DestroyBody(self.legs[1])
|
self.world.DestroyBody(self.legs[1])
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self._destroy()
|
self._destroy()
|
||||||
self.world.contactListener_keepref = ContactDetector(self)
|
self.world.contactListener_keepref = ContactDetector(self)
|
||||||
@@ -288,7 +294,10 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
|
|
||||||
self.drawlist = [self.lander] + self.legs
|
self.drawlist = [self.lander] + self.legs
|
||||||
|
|
||||||
|
if not return_info:
|
||||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
|
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
|
||||||
|
else:
|
||||||
|
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
|
||||||
|
|
||||||
def _create_particle(self, mass, x, y, ttl):
|
def _create_particle(self, mass, x, y, ttl):
|
||||||
p = self.world.CreateDynamicBody(
|
p = self.world.CreateDynamicBody(
|
||||||
|
@@ -162,12 +162,21 @@ class AcrobotEnv(core.Env):
|
|||||||
self.action_space = spaces.Discrete(3)
|
self.action_space = spaces.Discrete(3)
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
|
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
|
||||||
np.float32
|
np.float32
|
||||||
)
|
)
|
||||||
|
if not return_info:
|
||||||
return self._get_ob()
|
return self._get_ob()
|
||||||
|
else:
|
||||||
|
return self._get_ob(), {}
|
||||||
|
|
||||||
def step(self, a):
|
def step(self, a):
|
||||||
s = self.state
|
s = self.state
|
||||||
|
@@ -165,11 +165,20 @@ class CartPoleEnv(gym.Env):
|
|||||||
|
|
||||||
return np.array(self.state, dtype=np.float32), reward, done, {}
|
return np.array(self.state, dtype=np.float32), reward, done, {}
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
|
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
|
||||||
self.steps_beyond_done = None
|
self.steps_beyond_done = None
|
||||||
|
if not return_info:
|
||||||
return np.array(self.state, dtype=np.float32)
|
return np.array(self.state, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
return np.array(self.state, dtype=np.float32), {}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self, mode="human"):
|
||||||
screen_width = 600
|
screen_width = 600
|
||||||
|
@@ -130,10 +130,19 @@ class Continuous_MountainCarEnv(gym.Env):
|
|||||||
self.state = np.array([position, velocity], dtype=np.float32)
|
self.state = np.array([position, velocity], dtype=np.float32)
|
||||||
return self.state, reward, done, {}
|
return self.state, reward, done, {}
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
|
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
|
||||||
|
if not return_info:
|
||||||
return np.array(self.state, dtype=np.float32)
|
return np.array(self.state, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
return np.array(self.state, dtype=np.float32), {}
|
||||||
|
|
||||||
def _height(self, xs):
|
def _height(self, xs):
|
||||||
return np.sin(3 * xs) * 0.45 + 0.55
|
return np.sin(3 * xs) * 0.45 + 0.55
|
||||||
|
@@ -93,10 +93,19 @@ class MountainCarEnv(gym.Env):
|
|||||||
self.state = (position, velocity)
|
self.state = (position, velocity)
|
||||||
return np.array(self.state, dtype=np.float32), reward, done, {}
|
return np.array(self.state, dtype=np.float32), reward, done, {}
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
|
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
|
||||||
|
if not return_info:
|
||||||
return np.array(self.state, dtype=np.float32)
|
return np.array(self.state, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
return np.array(self.state, dtype=np.float32), {}
|
||||||
|
|
||||||
def _height(self, xs):
|
def _height(self, xs):
|
||||||
return np.sin(3 * xs) * 0.45 + 0.55
|
return np.sin(3 * xs) * 0.45 + 0.55
|
||||||
|
@@ -109,12 +109,21 @@ class PendulumEnv(gym.Env):
|
|||||||
self.state = np.array([newth, newthdot])
|
self.state = np.array([newth, newthdot])
|
||||||
return self._get_obs(), -costs, False, {}
|
return self._get_obs(), -costs, False, {}
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
high = np.array([np.pi, 1])
|
high = np.array([np.pi, 1])
|
||||||
self.state = self.np_random.uniform(low=-high, high=high)
|
self.state = self.np_random.uniform(low=-high, high=high)
|
||||||
self.last_u = None
|
self.last_u = None
|
||||||
|
if not return_info:
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
else:
|
||||||
|
return self._get_obs(), {}
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
theta, thetadot = self.state
|
theta, thetadot = self.state
|
||||||
|
@@ -103,11 +103,20 @@ class MujocoEnv(gym.Env):
|
|||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.sim.reset()
|
self.sim.reset()
|
||||||
ob = self.reset_model()
|
ob = self.reset_model()
|
||||||
|
if not return_info:
|
||||||
return ob
|
return ob
|
||||||
|
else:
|
||||||
|
return ob, {}
|
||||||
|
|
||||||
def set_state(self, qpos, qvel):
|
def set_state(self, qpos, qvel):
|
||||||
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
|
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
|
||||||
|
@@ -153,11 +153,19 @@ class BlackjackEnv(gym.Env):
|
|||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))
|
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.dealer = draw_hand(self.np_random)
|
self.dealer = draw_hand(self.np_random)
|
||||||
self.player = draw_hand(self.np_random)
|
self.player = draw_hand(self.np_random)
|
||||||
|
if not return_info:
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
else:
|
||||||
|
return self._get_obs(), {}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self, mode="human"):
|
||||||
player_sum, dealer_card_value, usable_ace = self._get_obs()
|
player_sum, dealer_card_value, usable_ace = self._get_obs()
|
||||||
|
@@ -102,11 +102,20 @@ class CliffWalkingEnv(Env):
|
|||||||
self.lastaction = a
|
self.lastaction = a
|
||||||
return (int(s), r, d, {"prob": p})
|
return (int(s), r, d, {"prob": p})
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||||
self.lastaction = None
|
self.lastaction = None
|
||||||
|
if not return_info:
|
||||||
return int(self.s)
|
return int(self.s)
|
||||||
|
else:
|
||||||
|
return int(self.s), {"prob": 1}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self, mode="human"):
|
||||||
outfile = StringIO() if mode == "ansi" else sys.stdout
|
outfile = StringIO() if mode == "ansi" else sys.stdout
|
||||||
|
@@ -212,11 +212,21 @@ class FrozenLakeEnv(Env):
|
|||||||
self.lastaction = a
|
self.lastaction = a
|
||||||
return (int(s), r, d, {"prob": p})
|
return (int(s), r, d, {"prob": p})
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||||
self.lastaction = None
|
self.lastaction = None
|
||||||
|
|
||||||
|
if not return_info:
|
||||||
return int(self.s)
|
return int(self.s)
|
||||||
|
else:
|
||||||
|
return int(self.s), {"prob": 1}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self, mode="human"):
|
||||||
desc = self.desc.tolist()
|
desc = self.desc.tolist()
|
||||||
|
@@ -212,11 +212,20 @@ class TaxiEnv(Env):
|
|||||||
self.lastaction = a
|
self.lastaction = a
|
||||||
return (int(s), r, d, {"prob": p})
|
return (int(s), r, d, {"prob": p})
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||||
self.lastaction = None
|
self.lastaction = None
|
||||||
|
if not return_info:
|
||||||
return int(self.s)
|
return int(self.s)
|
||||||
|
else:
|
||||||
|
return int(self.s), {"prob": 1}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self, mode="human"):
|
||||||
outfile = StringIO() if mode == "ansi" else sys.stdout
|
outfile = StringIO() if mode == "ansi" else sys.stdout
|
||||||
|
@@ -215,6 +215,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
def reset_async(
|
def reset_async(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Send the calls to :obj:`reset` to each sub-environment.
|
"""Send the calls to :obj:`reset` to each sub-environment.
|
||||||
@@ -248,6 +249,8 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
single_kwargs = {}
|
single_kwargs = {}
|
||||||
if single_seed is not None:
|
if single_seed is not None:
|
||||||
single_kwargs["seed"] = single_seed
|
single_kwargs["seed"] = single_seed
|
||||||
|
if return_info:
|
||||||
|
single_kwargs["return_info"] = return_info
|
||||||
if options is not None:
|
if options is not None:
|
||||||
single_kwargs["options"] = options
|
single_kwargs["options"] = options
|
||||||
|
|
||||||
@@ -255,7 +258,11 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self._state = AsyncState.WAITING_RESET
|
self._state = AsyncState.WAITING_RESET
|
||||||
|
|
||||||
def reset_wait(
|
def reset_wait(
|
||||||
self, timeout=None, seed: Optional[int] = None, options: Optional[dict] = None
|
self,
|
||||||
|
timeout=None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
@@ -270,6 +277,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
-------
|
-------
|
||||||
element of :attr:`~VectorEnv.observation_space`
|
element of :attr:`~VectorEnv.observation_space`
|
||||||
A batch of observations from the vectorized environment.
|
A batch of observations from the vectorized environment.
|
||||||
|
infos : list of dicts containing metadata
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
@@ -300,6 +308,19 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self._raise_if_errors(successes)
|
self._raise_if_errors(successes)
|
||||||
self._state = AsyncState.DEFAULT
|
self._state = AsyncState.DEFAULT
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
results, infos = zip(*results)
|
||||||
|
infos = list(infos)
|
||||||
|
|
||||||
|
if not self.shared_memory:
|
||||||
|
self.observations = concatenate(
|
||||||
|
self.single_observation_space, results, self.observations
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
deepcopy(self.observations) if self.copy else self.observations
|
||||||
|
), infos
|
||||||
|
else:
|
||||||
if not self.shared_memory:
|
if not self.shared_memory:
|
||||||
self.observations = concatenate(
|
self.observations = concatenate(
|
||||||
self.single_observation_space, results, self.observations
|
self.single_observation_space, results, self.observations
|
||||||
@@ -618,8 +639,13 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
|||||||
while True:
|
while True:
|
||||||
command, data = pipe.recv()
|
command, data = pipe.recv()
|
||||||
if command == "reset":
|
if command == "reset":
|
||||||
|
if "return_info" in data and data["return_info"] == True:
|
||||||
|
observation, info = env.reset(**data)
|
||||||
|
pipe.send(((observation, info), True))
|
||||||
|
else:
|
||||||
observation = env.reset(**data)
|
observation = env.reset(**data)
|
||||||
pipe.send((observation, True))
|
pipe.send((observation, True))
|
||||||
|
|
||||||
elif command == "step":
|
elif command == "step":
|
||||||
observation, reward, done, info = env.step(data)
|
observation, reward, done, info = env.step(data)
|
||||||
if done:
|
if done:
|
||||||
@@ -677,6 +703,13 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
|||||||
while True:
|
while True:
|
||||||
command, data = pipe.recv()
|
command, data = pipe.recv()
|
||||||
if command == "reset":
|
if command == "reset":
|
||||||
|
if "return_info" in data and data["return_info"] == True:
|
||||||
|
observation, info = env.reset(**data)
|
||||||
|
write_to_shared_memory(
|
||||||
|
observation_space, index, observation, shared_memory
|
||||||
|
)
|
||||||
|
pipe.send(((None, info), True))
|
||||||
|
else:
|
||||||
observation = env.reset(**data)
|
observation = env.reset(**data)
|
||||||
write_to_shared_memory(
|
write_to_shared_memory(
|
||||||
observation_space, index, observation, shared_memory
|
observation_space, index, observation, shared_memory
|
||||||
|
@@ -89,6 +89,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
def reset_wait(
|
def reset_wait(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
if seed is None:
|
if seed is None:
|
||||||
@@ -99,19 +100,34 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
self._dones[:] = False
|
self._dones[:] = False
|
||||||
observations = []
|
observations = []
|
||||||
|
data_list = []
|
||||||
for env, single_seed in zip(self.envs, seed):
|
for env, single_seed in zip(self.envs, seed):
|
||||||
single_kwargs = {}
|
|
||||||
|
kwargs = {}
|
||||||
if single_seed is not None:
|
if single_seed is not None:
|
||||||
single_kwargs["seed"] = single_seed
|
kwargs["seed"] = single_seed
|
||||||
if options is not None:
|
if options is not None:
|
||||||
single_kwargs["options"] = options
|
kwargs["options"] = options
|
||||||
observation = env.reset(**single_kwargs)
|
if return_info == True:
|
||||||
|
kwargs["return_info"] = return_info
|
||||||
|
|
||||||
|
if not return_info:
|
||||||
|
observation = env.reset(**kwargs)
|
||||||
observations.append(observation)
|
observations.append(observation)
|
||||||
|
else:
|
||||||
|
observation, data = env.reset(**kwargs)
|
||||||
|
observations.append(observation)
|
||||||
|
data_list.append(data)
|
||||||
|
|
||||||
self.observations = concatenate(
|
self.observations = concatenate(
|
||||||
self.single_observation_space, observations, self.observations
|
self.single_observation_space, observations, self.observations
|
||||||
)
|
)
|
||||||
|
if not return_info:
|
||||||
return deepcopy(self.observations) if self.copy else self.observations
|
return deepcopy(self.observations) if self.copy else self.observations
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
deepcopy(self.observations) if self.copy else self.observations
|
||||||
|
), data_list
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
self._actions = iterate(self.action_space, actions)
|
self._actions = iterate(self.action_space, actions)
|
||||||
|
@@ -49,6 +49,7 @@ class VectorEnv(gym.Env):
|
|||||||
def reset_async(
|
def reset_async(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
@@ -56,6 +57,7 @@ class VectorEnv(gym.Env):
|
|||||||
def reset_wait(
|
def reset_wait(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -64,6 +66,7 @@ class VectorEnv(gym.Env):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
r"""Reset all sub-environments and return a batch of initial observations.
|
r"""Reset all sub-environments and return a batch of initial observations.
|
||||||
@@ -73,8 +76,8 @@ class VectorEnv(gym.Env):
|
|||||||
element of :attr:`observation_space`
|
element of :attr:`observation_space`
|
||||||
A batch of observations from the vectorized environment.
|
A batch of observations from the vectorized environment.
|
||||||
"""
|
"""
|
||||||
self.reset_async(seed=seed, options=options)
|
self.reset_async(seed=seed, return_info=return_info, options=options)
|
||||||
return self.reset_wait(seed=seed, options=options)
|
return self.reset_wait(seed=seed, return_info=return_info, options=options)
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
pass
|
pass
|
||||||
|
@@ -63,13 +63,26 @@ class NormalizeObservation(gym.core.Wrapper):
|
|||||||
obs = self.normalize(np.array([obs]))[0]
|
obs = self.normalize(np.array([obs]))[0]
|
||||||
return obs, rews, dones, infos
|
return obs, rews, dones, infos
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(
|
||||||
obs = self.env.reset(**kwargs)
|
self,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
obs = None
|
||||||
|
info = None
|
||||||
|
if not return_info:
|
||||||
|
obs = self.env.reset(seed=seed, options=options)
|
||||||
|
else:
|
||||||
|
obs, info = self.env.reset(seed=seed, return_info=True, options=options)
|
||||||
if self.is_vector_env:
|
if self.is_vector_env:
|
||||||
obs = self.normalize(obs)
|
obs = self.normalize(obs)
|
||||||
else:
|
else:
|
||||||
obs = self.normalize(np.array([obs]))[0]
|
obs = self.normalize(np.array([obs]))[0]
|
||||||
|
if not return_info:
|
||||||
return obs
|
return obs
|
||||||
|
else:
|
||||||
|
return obs, info
|
||||||
|
|
||||||
def normalize(self, obs):
|
def normalize(self, obs):
|
||||||
self.obs_rms.update(obs)
|
self.obs_rms.update(obs)
|
||||||
|
@@ -55,6 +55,23 @@ def test_env(spec):
|
|||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("spec", spec_list)
|
||||||
|
def test_reset_info(spec):
|
||||||
|
|
||||||
|
with pytest.warns(None) as warnings:
|
||||||
|
env = spec.make()
|
||||||
|
|
||||||
|
ob_space = env.observation_space
|
||||||
|
obs = env.reset()
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
obs = env.reset(return_info=False)
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
obs, info = env.reset(return_info=True)
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
# Run a longer rollout on some environments
|
# Run a longer rollout on some environments
|
||||||
def test_random_rollout():
|
def test_random_rollout():
|
||||||
for env in [envs.make("CartPole-v0"), envs.make("FrozenLake-v1")]:
|
for env in [envs.make("CartPole-v0"), envs.make("FrozenLake-v1")]:
|
||||||
|
@@ -35,13 +35,22 @@ class UnknownSpacesEnv(core.Env):
|
|||||||
on external resources), it is not encouraged.
|
on external resources), it is not encouraged.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
|
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
|
||||||
)
|
)
|
||||||
self.action_space = spaces.Discrete(3)
|
self.action_space = spaces.Discrete(3)
|
||||||
|
if not return_info:
|
||||||
return self.observation_space.sample() # Dummy observation
|
return self.observation_space.sample() # Dummy observation
|
||||||
|
else:
|
||||||
|
return self.observation_space.sample(), {} # Dummy observation with info
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
observation = self.observation_space.sample() # Dummy observation
|
observation = self.observation_space.sample() # Dummy observation
|
||||||
|
@@ -40,6 +40,32 @@ def test_reset_async_vector_env(shared_memory):
|
|||||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||||
assert observations.shape == env.observation_space.shape
|
assert observations.shape == env.observation_space.shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
|
observations = env.reset(return_info=False)
|
||||||
|
finally:
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
assert isinstance(env.observation_space, Box)
|
||||||
|
assert isinstance(observations, np.ndarray)
|
||||||
|
assert observations.dtype == env.observation_space.dtype
|
||||||
|
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||||
|
assert observations.shape == env.observation_space.shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
|
observations, infos = env.reset(return_info=True)
|
||||||
|
finally:
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
assert isinstance(env.observation_space, Box)
|
||||||
|
assert isinstance(observations, np.ndarray)
|
||||||
|
assert observations.dtype == env.observation_space.dtype
|
||||||
|
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||||
|
assert observations.shape == env.observation_space.shape
|
||||||
|
assert isinstance(infos, list)
|
||||||
|
assert all([isinstance(info, dict) for info in infos])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||||
|
@@ -31,6 +31,37 @@ def test_reset_sync_vector_env():
|
|||||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||||
assert observations.shape == env.observation_space.shape
|
assert observations.shape == env.observation_space.shape
|
||||||
|
|
||||||
|
del observations
|
||||||
|
|
||||||
|
try:
|
||||||
|
env = SyncVectorEnv(env_fns)
|
||||||
|
observations = env.reset(return_info=False)
|
||||||
|
finally:
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
assert isinstance(env.observation_space, Box)
|
||||||
|
assert isinstance(observations, np.ndarray)
|
||||||
|
assert observations.dtype == env.observation_space.dtype
|
||||||
|
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||||
|
assert observations.shape == env.observation_space.shape
|
||||||
|
|
||||||
|
del observations
|
||||||
|
|
||||||
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
|
try:
|
||||||
|
env = SyncVectorEnv(env_fns)
|
||||||
|
observations, infos = env.reset(return_info=True)
|
||||||
|
finally:
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
assert isinstance(env.observation_space, Box)
|
||||||
|
assert isinstance(observations, np.ndarray)
|
||||||
|
assert observations.dtype == env.observation_space.dtype
|
||||||
|
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||||
|
assert observations.shape == env.observation_space.shape
|
||||||
|
assert isinstance(infos, list)
|
||||||
|
assert all([isinstance(info, dict) for info in infos])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||||
def test_step_sync_vector_env(use_single_action_space):
|
def test_step_sync_vector_env(use_single_action_space):
|
||||||
|
@@ -23,10 +23,19 @@ class DummyRewardEnv(gym.Env):
|
|||||||
self.t += 1
|
self.t += 1
|
||||||
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
|
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: Optional[bool] = False,
|
||||||
|
options: Optional[dict] = None
|
||||||
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.t = self.return_reward_idx
|
self.t = self.return_reward_idx
|
||||||
|
if not return_info:
|
||||||
return np.array([self.t])
|
return np.array([self.t])
|
||||||
|
else:
|
||||||
|
return np.array([self.t]), {}
|
||||||
|
|
||||||
|
|
||||||
def make_env(return_reward_idx):
|
def make_env(return_reward_idx):
|
||||||
@@ -47,6 +56,20 @@ def test_normalize_observation():
|
|||||||
assert_almost_equal(env.obs_rms.mean, 1.0, decimal=4)
|
assert_almost_equal(env.obs_rms.mean, 1.0, decimal=4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_reset_info():
|
||||||
|
env = DummyRewardEnv(return_reward_idx=0)
|
||||||
|
env = NormalizeObservation(env)
|
||||||
|
obs = env.reset()
|
||||||
|
assert isinstance(obs, np.ndarray)
|
||||||
|
del obs
|
||||||
|
obs = env.reset(return_info=False)
|
||||||
|
assert isinstance(obs, np.ndarray)
|
||||||
|
del obs
|
||||||
|
obs, info = env.reset(return_info=True)
|
||||||
|
assert isinstance(obs, np.ndarray)
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_return():
|
def test_normalize_return():
|
||||||
env = DummyRewardEnv(return_reward_idx=0)
|
env = DummyRewardEnv(return_reward_idx=0)
|
||||||
env = NormalizeReward(env)
|
env = NormalizeReward(env)
|
||||||
|
21
tests/wrappers/test_order_enforcing.py
Normal file
21
tests/wrappers/test_order_enforcing.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gym
|
||||||
|
from gym.wrappers import OrderEnforcing
|
||||||
|
|
||||||
|
|
||||||
|
def test_order_enforcing_reset_info():
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env = OrderEnforcing(env)
|
||||||
|
ob_space = env.observation_space
|
||||||
|
obs = env.reset()
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
del obs
|
||||||
|
obs = env.reset(return_info=False)
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
del obs
|
||||||
|
obs, info = env.reset(return_info=True)
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
assert isinstance(info, dict)
|
@@ -1,5 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.wrappers import RecordEpisodeStatistics
|
from gym.wrappers import RecordEpisodeStatistics
|
||||||
|
|
||||||
@@ -24,6 +26,18 @@ def test_record_episode_statistics(env_id, deque_size):
|
|||||||
assert len(env.length_queue) == deque_size
|
assert len(env.length_queue) == deque_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_episode_statistics_reset_info():
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env = RecordEpisodeStatistics(env)
|
||||||
|
ob_space = env.observation_space
|
||||||
|
obs = env.reset()
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
del obs
|
||||||
|
obs, info = env.reset(return_info=True)
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)]
|
("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)]
|
||||||
)
|
)
|
||||||
|
@@ -1,7 +1,9 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.wrappers import (
|
from gym.wrappers import (
|
||||||
RecordEpisodeStatistics,
|
RecordEpisodeStatistics,
|
||||||
@@ -29,6 +31,36 @@ def test_record_video_using_default_trigger():
|
|||||||
shutil.rmtree("videos")
|
shutil.rmtree("videos")
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_video_reset_return_info():
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||||
|
ob_space = env.observation_space
|
||||||
|
obs, info = env.reset(return_info=True)
|
||||||
|
env.close()
|
||||||
|
assert os.path.isdir("videos")
|
||||||
|
shutil.rmtree("videos")
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||||
|
ob_space = env.observation_space
|
||||||
|
obs = env.reset(return_info=False)
|
||||||
|
env.close()
|
||||||
|
assert os.path.isdir("videos")
|
||||||
|
shutil.rmtree("videos")
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||||
|
ob_space = env.observation_space
|
||||||
|
obs = env.reset()
|
||||||
|
env.close()
|
||||||
|
assert os.path.isdir("videos")
|
||||||
|
shutil.rmtree("videos")
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
|
||||||
|
|
||||||
def test_record_video_step_trigger():
|
def test_record_video_step_trigger():
|
||||||
env = gym.make("CartPole-v1")
|
env = gym.make("CartPole-v1")
|
||||||
env._max_episode_steps = 20
|
env._max_episode_steps = 20
|
||||||
|
21
tests/wrappers/test_time_limit.py
Normal file
21
tests/wrappers/test_time_limit.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gym
|
||||||
|
from gym.wrappers import TimeLimit
|
||||||
|
|
||||||
|
|
||||||
|
def test_time_limit_reset_info():
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env = TimeLimit(env)
|
||||||
|
ob_space = env.observation_space
|
||||||
|
obs = env.reset()
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
del obs
|
||||||
|
obs = env.reset(return_info=False)
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
del obs
|
||||||
|
obs, info = env.reset(return_info=True)
|
||||||
|
assert ob_space.contains(obs)
|
||||||
|
assert isinstance(info, dict)
|
Reference in New Issue
Block a user