New Step API with terminated, truncated bools instead of done (#2752)

Arjun KG
2022-07-10 02:18:06 +05:30
committed by GitHub
parent e3c05c2b59
commit 907b1b20dd
84 changed files with 1176 additions and 411 deletions

View File

@@ -130,11 +130,16 @@ class Env(Generic[ObsType, ActType], metaclass=decorator):
def np_random(self, value: RandomNumberGenerator):
self._np_random = value
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
"""Run one timestep of the environment's dynamics.
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
Accepts an action and returns a tuple `(observation, reward, done, info)`.
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple
`(observation, reward, done, info)`. The latter is deprecated and will be removed in future versions.
Args:
action (ActType): an action provided by the agent
@@ -143,14 +148,21 @@ class Env(Generic[ObsType, ActType], metaclass=decorator):
observation (object): this will be an element of the environment's :attr:`observation_space`.
This may, for instance, be a numpy array containing the positions and velocities of certain objects.
reward (float): The amount of reward returned as a result of taking the action.
terminated (bool): whether a `terminal state` (as defined under the MDP of the task) is reached.
In this case further step() calls could return undefined results.
truncated (bool): whether a truncation condition outside the scope of the MDP is satisfied.
Typically a timelimit, but could also be used to indicate agent physically going out of bounds.
Can be used to end the episode prematurely before a `terminal state` is reached.
info (dictionary): `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
This might, for instance, contain: metrics that describe the agent's performance state, variables that are
hidden from observations, or individual reward terms that are combined to produce the total reward.
It also can contain information that distinguishes truncation and termination, however this is deprecated in favour
of returning two booleans, and will be removed in a future version.
(deprecated)
done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results.
A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully,
a certain timelimit was exceeded, or the physics simulation has entered an invalid state.
info (dictionary): A dictionary that may contain additional information regarding the reason for a ``done`` signal.
`info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
This might, for instance, contain: metrics that describe the agent's performance state, variables that are
hidden from observations, information that distinguishes truncation and termination or individual reward terms
that are combined to produce the total reward
"""
raise NotImplementedError
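As a minimal sketch of how a rollout loop consumes the new five-tuple (assuming the registered "CartPole-v1" id and the `new_step_api=True` flag that this commit adds to `gym.make` further below):

import gym

env = gym.make("CartPole-v1", new_step_api=True)  # opt in to the five-value return
obs = env.reset()
while True:
    # a trained policy would replace the random action here
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        # terminated: a terminal state of the MDP was reached (the pole fell over)
        # truncated: the episode was cut short outside the MDP (the 500-step time limit)
        obs = env.reset()
        break
env.close()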
@@ -298,11 +310,12 @@ class Wrapper(Env[ObsType, ActType]):
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""
def __init__(self, env: Env):
def __init__(self, env: Env, new_step_api: bool = False):
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
Args:
env: The environment to wrap
new_step_api: Whether the wrapper's step method will output in new or old step API
"""
self.env = env
@@ -310,6 +323,13 @@ class Wrapper(Env[ObsType, ActType]):
self._observation_space: Optional[spaces.Space] = None
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
self._metadata: Optional[dict] = None
self.new_step_api = new_step_api
if not self.new_step_api:
deprecation(
"Initializing wrapper in old step API which returns one bool instead of two. "
"It is recommended to set `new_step_api=True` to use the new step API. This will be the default behaviour in future. "
)
def __getattr__(self, name):
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
@@ -391,9 +411,17 @@ class Wrapper(Env[ObsType, ActType]):
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
"""Steps through the environment with action."""
return self.env.step(action)
from gym.utils.step_api_compatibility import ( # avoid circular import
step_api_compatibility,
)
return step_api_compatibility(self.env.step(action), self.new_step_api)
def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
"""Resets the environment with kwargs."""
@@ -463,8 +491,13 @@ class ObservationWrapper(Wrapper):
def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
observation, reward, done, info = self.env.step(action)
return self.observation(observation), reward, done, info
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return self.observation(observation), reward, terminated, truncated, info
else:
observation, reward, done, info = step_returns
return self.observation(observation), reward, done, info
def observation(self, observation):
"""Returns a modified observation."""
@@ -497,8 +530,13 @@ class RewardWrapper(Wrapper):
def step(self, action):
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
observation, reward, done, info = self.env.step(action)
return observation, self.reward(reward), done, info
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return observation, self.reward(reward), terminated, truncated, info
else:
observation, reward, done, info = step_returns
return observation, self.reward(reward), done, info
def reward(self, reward):
"""Returns a modified ``reward``."""

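To round out the core wrapper changes, a sketch of opting a wrapper stack into the new API (using `TimeLimit`, whose constructor gains the same flag later in this commit, and the registered "CartPole-v1" id as assumptions):

import gym
from gym.wrappers import TimeLimit

base = gym.make("CartPole-v1", new_step_api=True)
env = TimeLimit(base, max_episode_steps=100, new_step_api=True)  # no deprecation notice
env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
# leaving new_step_api at its default (False) would keep the old 4-tuple output
# and log the deprecation message shown above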
View File

@@ -599,15 +599,14 @@ class BipedalWalker(gym.Env, EzPickle):
reward -= 0.00035 * MOTORS_TORQUE * np.clip(np.abs(a), 0, 1)
# normalized to about -50.0 using heuristic, more optimal agent should spend less
done = False
terminated = False
if self.game_over or pos[0] < 0:
reward = -100
done = True
terminated = True
if pos[0] > (TERRAIN_LENGTH - TERRAIN_GRASS) * TERRAIN_STEP:
done = True
terminated = True
self.renderer.render_step()
return np.array(state, dtype=np.float32), reward, done, {}
return np.array(state, dtype=np.float32), reward, terminated, False, {}
def render(self, mode: str = "human"):
if self.render_mode is not None:
@@ -789,9 +788,9 @@ if __name__ == "__main__":
SUPPORT_KNEE_ANGLE = +0.1
supporting_knee_angle = SUPPORT_KNEE_ANGLE
while True:
s, r, done, info = env.step(a)
s, r, terminated, truncated, info = env.step(a)
total_reward += r
if steps % 20 == 0 or done:
if steps % 20 == 0 or terminated or truncated:
print("\naction " + str([f"{x:+0.2f}" for x in a]))
print(f"step {steps} total_reward {total_reward:+0.2f}")
print("hull " + str([f"{x:+0.2f}" for x in s[0:4]]))
@@ -854,5 +853,5 @@ if __name__ == "__main__":
a[3] = knee_todo[1]
a = np.clip(0.5 * a, -1.0, 1.0)
if done:
if terminated or truncated:
break

View File

@@ -526,8 +526,8 @@ class CarRacing(gym.Env, EzPickle):
self.state = self._render("single_state_pixels")
step_reward = 0
done = False
info = {}
terminated = False
truncated = False
if action is not None: # First step without action, called from reset()
self.reward -= 0.1
# We actually don't want to count fuel spent, we want car to be faster.
@@ -536,18 +536,17 @@ class CarRacing(gym.Env, EzPickle):
step_reward = self.reward - self.prev_reward
self.prev_reward = self.reward
if self.tile_visited_count == len(self.track) or self.new_lap:
done = True
# Termination due to finishing lap
# Truncation due to finishing lap
# This should not be treated as a failure
# but like a timeout
info["TimeLimit.truncated"] = True
truncated = True
x, y = self.car.hull.position
if abs(x) > PLAYFIELD or abs(y) > PLAYFIELD:
done = True
terminated = True
step_reward = -100
self.renderer.render_step()
return self.state, step_reward, done, info
return self.state, step_reward, terminated, truncated, {}
def render(self, mode: str = "human"):
if self.render_mode is not None:
@@ -811,13 +810,13 @@ if __name__ == "__main__":
restart = False
while True:
register_input()
s, r, done, info = env.step(a)
s, r, terminated, truncated, info = env.step(a)
total_reward += r
if steps % 200 == 0 or done:
if steps % 200 == 0 or terminated or truncated:
print("\naction " + str([f"{x:+0.2f}" for x in a]))
print(f"step {steps} total_reward {total_reward:+0.2f}")
steps += 1
isopen = env.render()
if done or restart or isopen is False:
if terminated or truncated or restart or isopen is False:
break
env.close()

View File

@@ -11,6 +11,7 @@ from gym import error, spaces
from gym.error import DependencyNotInstalled
from gym.utils import EzPickle, colorize
from gym.utils.renderer import Renderer
from gym.utils.step_api_compatibility import step_api_compatibility
try:
import Box2D
@@ -577,16 +578,15 @@ class LunarLander(gym.Env, EzPickle):
) # less fuel spent is better, about -30 for heuristic landing
reward -= s_power * 0.03
done = False
terminated = False
if self.game_over or abs(state[0]) >= 1.0:
done = True
terminated = True
reward = -100
if not self.lander.awake:
done = True
terminated = True
reward = +100
self.renderer.render_step()
return np.array(state, dtype=np.float32), reward, done, {}
return np.array(state, dtype=np.float32), reward, terminated, False, {}
def render(self, mode="human"):
if self.render_mode is not None:
@@ -771,7 +771,7 @@ def demo_heuristic_lander(env, seed=None, render=False):
s = env.reset(seed=seed)
while True:
a = heuristic(env, s)
s, r, done, info = env.step(a)
s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True)
total_reward += r
if render:
@@ -779,11 +779,11 @@ def demo_heuristic_lander(env, seed=None, render=False):
if still_open is False:
break
if steps % 20 == 0 or done:
if steps % 20 == 0 or terminated or truncated:
print("observations:", " ".join([f"{x:+0.2f}" for x in s]))
print(f"step {steps} total_reward {total_reward:+0.2f}")
steps += 1
if done:
if terminated or truncated:
break
if render:
env.close()

View File

@@ -86,12 +86,12 @@ class AcrobotEnv(core.Env):
Each parameter in the underlying state (`theta1`, `theta2`, and the two angular velocities) is initialized
uniformly between -0.1 and 0.1. This means both links are pointing downwards with some initial stochasticity.
### Episode Termination
### Episode End
The episode terminates if one of the following occurs:
1. The free end reaches the target height, which is constructed as:
The episode ends if one of the following occurs:
1. Termination: The free end reaches the target height, which is constructed as:
`-cos(theta1) - cos(theta2 + theta1) > 1.0`
2. Episode length is greater than 500 (200 for v0)
2. Truncation: Episode length is greater than 500 (200 for v0)
### Arguments
@@ -226,11 +226,11 @@ class AcrobotEnv(core.Env):
ns[2] = bound(ns[2], -self.MAX_VEL_1, self.MAX_VEL_1)
ns[3] = bound(ns[3], -self.MAX_VEL_2, self.MAX_VEL_2)
self.state = ns
terminal = self._terminal()
reward = -1.0 if not terminal else 0.0
terminated = self._terminal()
reward = -1.0 if not terminated else 0.0
self.renderer.render_step()
return self._get_ob(), reward, terminal, {}
return (self._get_ob(), reward, terminated, False, {})
def _get_ob(self):
s = self.state

View File

@@ -65,12 +65,13 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
All observations are assigned a uniformly random value in `(-0.05, 0.05)`
### Episode Termination
### Episode End
The episode terminates if any one of the following occurs:
1. Pole Angle is greater than ±12°
2. Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
3. Episode length is greater than 500 (200 for v0)
The episode ends if any one of the following occurs:
1. Termination: Pole Angle is greater than ±12°
2. Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
3. Truncation: Episode length is greater than 500 (200 for v0)
### Arguments
@@ -126,7 +127,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.isopen = True
self.state = None
self.steps_beyond_done = None
self.steps_beyond_terminated = None
def step(self, action):
err_msg = f"{action!r} ({type(action)}) invalid"
@@ -160,32 +161,32 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.state = (x, x_dot, theta, theta_dot)
done = bool(
terminated = bool(
x < -self.x_threshold
or x > self.x_threshold
or theta < -self.theta_threshold_radians
or theta > self.theta_threshold_radians
)
if not done:
if not terminated:
reward = 1.0
elif self.steps_beyond_done is None:
elif self.steps_beyond_terminated is None:
# Pole just fell!
self.steps_beyond_done = 0
self.steps_beyond_terminated = 0
reward = 1.0
else:
if self.steps_beyond_done == 0:
if self.steps_beyond_terminated == 0:
logger.warn(
"You are calling 'step()' even though this "
"environment has already returned done = True. You "
"should always call 'reset()' once you receive 'done = "
"environment has already returned terminated = True. You "
"should always call 'reset()' once you receive 'terminated = "
"True' -- any further steps are undefined behavior."
)
self.steps_beyond_done += 1
self.steps_beyond_terminated += 1
reward = 0.0
self.renderer.render_step()
return np.array(self.state, dtype=np.float32), reward, done, {}
return np.array(self.state, dtype=np.float32), reward, terminated, False, {}
def reset(
self,
@@ -201,7 +202,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
options, -0.05, 0.05 # default low
) # default high
self.state = self.np_random.uniform(low=low, high=high, size=(4,))
self.steps_beyond_done = None
self.steps_beyond_terminated = None
self.renderer.reset()
self.renderer.render_step()
if not return_info:

View File

@@ -84,11 +84,11 @@ class Continuous_MountainCarEnv(gym.Env):
The position of the car is assigned a uniform random value in `[-0.6 , -0.4]`.
The starting velocity of the car is always assigned to 0.
### Episode Termination
### Episode End
The episode terminates if either of the following happens:
1. The position of the car is greater than or equal to 0.45 (the goal position on top of the right hill)
2. The length of the episode is 999.
The episode ends if either of the following happens:
1. Termination: The position of the car is greater than or equal to 0.45 (the goal position on top of the right hill)
2. Truncation: The length of the episode is 999.
### Arguments
@@ -161,17 +161,18 @@ class Continuous_MountainCarEnv(gym.Env):
velocity = 0
# Convert a possible numpy bool to a Python bool.
done = bool(position >= self.goal_position and velocity >= self.goal_velocity)
terminated = bool(
position >= self.goal_position and velocity >= self.goal_velocity
)
reward = 0
if done:
if terminated:
reward = 100.0
reward -= math.pow(action[0], 2) * 0.1
self.state = np.array([position, velocity], dtype=np.float32)
self.renderer.render_step()
return self.state, reward, done, {}
return self.state, reward, terminated, False, {}
def reset(
self,

View File

@@ -78,11 +78,11 @@ class MountainCarEnv(gym.Env):
The position of the car is assigned a uniform random value in *[-0.6 , -0.4]*.
The starting velocity of the car is always assigned to 0.
### Episode Termination
### Episode End
The episode terminates if either of the following happens:
1. The position of the car is greater than or equal to 0.5 (the goal position on top of the right hill)
2. The length of the episode is 200.
The episode ends if either of the following happens:
1. Termination: The position of the car is greater than or equal to 0.5 (the goal position on top of the right hill)
2. Truncation: The length of the episode is 200.
### Arguments
@@ -139,13 +139,14 @@ class MountainCarEnv(gym.Env):
if position == self.min_position and velocity < 0:
velocity = 0
done = bool(position >= self.goal_position and velocity >= self.goal_velocity)
terminated = bool(
position >= self.goal_position and velocity >= self.goal_velocity
)
reward = -1.0
self.state = (position, velocity)
self.renderer.render_step()
return np.array(self.state, dtype=np.float32), reward, done, {}
return np.array(self.state, dtype=np.float32), reward, terminated, False, {}
def reset(
self,

View File

@@ -68,9 +68,9 @@ class PendulumEnv(gym.Env):
The starting state is a random angle in *[-pi, pi]* and a random angular velocity in *[-1,1]*.
### Episode Termination
### Episode Truncation
The episode terminates at 200 time steps.
The episode truncates at 200 time steps.
### Arguments
@@ -136,7 +136,7 @@ class PendulumEnv(gym.Env):
self.state = np.array([newth, newthdot])
self.renderer.render_step()
return self._get_obs(), -costs, False, {}
return self._get_obs(), -costs, False, False, {}
def reset(
self,

View File

@@ -41,13 +41,16 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
survive_reward = 1.0
reward = forward_reward - ctrl_cost - contact_cost + survive_reward
state = self.state_vector()
notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
done = not notdone
not_terminated = (
np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
)
terminated = not not_terminated
ob = self._get_obs()
return (
ob,
reward,
done,
terminated,
False,
dict(
reward_forward=forward_reward,
reward_ctrl=-ctrl_cost,

View File

@@ -97,9 +97,9 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
return is_healthy
@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated
def step(self, action):
xy_position_before = self.get_body_com("torso")[:2].copy()
@@ -121,7 +121,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
self.renderer.render_step()
reward = rewards - costs
done = self.done
terminated = self.terminated
observation = self._get_obs()
info = {
"reward_forward": forward_reward,
@@ -136,7 +136,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
"forward_reward": forward_reward,
}
return observation, reward, done, info
return observation, reward, terminated, False, info
def _get_obs(self):
position = self.sim.data.qpos.flat.copy()

View File

@@ -124,19 +124,19 @@ class AntEnv(MujocoEnv, utils.EzPickle):
to be slightly high, thereby indicating a standing up ant. The initial orientation
is designed to make it face forward as well.
### Episode Termination
### Episode End
The ant is said to be unhealthy if any of the following happens:
1. Any of the state space values is no longer finite
2. The z-coordinate of the torso is **not** in the closed interval given by `healthy_z_range` (defaults to [0.2, 1.0])
If `terminate_when_unhealthy=True` is passed during construction (which is the default),
the episode terminates when any of the following happens:
the episode ends when any of the following happens:
1. The episode duration reaches a 1000 timesteps
2. The ant is unhealthy
1. Truncation: The episode duration reaches a 1000 timesteps
2. Termination: The ant is unhealthy
If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded.
If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.
### Arguments
@@ -263,9 +263,9 @@ class AntEnv(MujocoEnv, utils.EzPickle):
return is_healthy
@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated
def step(self, action):
xy_position_before = self.get_body_com("torso")[:2].copy()
@@ -282,7 +282,7 @@ class AntEnv(MujocoEnv, utils.EzPickle):
costs = ctrl_cost = self.control_cost(action)
done = self.done
terminated = self.terminated
observation = self._get_obs()
info = {
"reward_forward": forward_reward,
@@ -303,7 +303,7 @@ class AntEnv(MujocoEnv, utils.EzPickle):
reward = rewards - costs
self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info
def _get_obs(self):
position = self.data.qpos.flat.copy()

View File

@@ -35,8 +35,14 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
reward_ctrl = -0.1 * np.square(action).sum()
reward_run = (xposafter - xposbefore) / self.dt
reward = reward_ctrl + reward_run
done = False
return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl)
terminated = False
return (
ob,
reward,
terminated,
False,
dict(reward_run=reward_run, reward_ctrl=reward_ctrl),
)
def _get_obs(self):
return np.concatenate(

View File

@@ -75,7 +75,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
observation = self._get_obs()
reward = forward_reward - ctrl_cost
done = False
terminated = False
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
@@ -83,7 +83,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
"reward_ctrl": -ctrl_cost,
}
return observation, reward, done, info
return observation, reward, terminated, False, info
def _get_obs(self):
position = self.sim.data.qpos.flat.copy()

View File

@@ -98,8 +98,8 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
normal noise with a mean of 0 and standard deviation of `reset_noise_scale` is added to the
initial velocity values of all zeros.
### Episode Termination
The episode terminates when the episode length is greater than 1000.
### Episode End
The episode truncates when the episode length is greater than 1000.
### Arguments
@@ -192,7 +192,7 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
observation = self._get_obs()
reward = forward_reward - ctrl_cost
done = False
terminated = False
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
@@ -201,7 +201,7 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
}
self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info
def _get_obs(self):
position = self.data.qpos.flat.copy()

View File

@@ -36,14 +36,14 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
reward += alive_bonus
reward -= 1e-3 * np.square(a).sum()
s = self.state_vector()
done = not (
terminated = not (
np.isfinite(s).all()
and (np.abs(s[2:]) < 100).all()
and (height > 0.7)
and (abs(ang) < 0.2)
)
ob = self._get_obs()
return ob, reward, done, {}
return ob, reward, terminated, False, {}
def _get_obs(self):
return np.concatenate(

View File

@@ -101,9 +101,9 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
return is_healthy
@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated
def _get_obs(self):
position = self.sim.data.qpos.flat.copy()
@@ -133,13 +133,13 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
}
return observation, reward, done, info
return observation, reward, terminated, False, info
def reset_model(self):
noise_low = -self._reset_noise_scale

View File

@@ -87,7 +87,7 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
(0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) with a uniform noise
in the range of [-`reset_noise_scale`, `reset_noise_scale`] added to the values for stochasticity.
### Episode Termination
### Episode End
The hopper is said to be unhealthy if any of the following happens:
1. An element of `observation[1:]` (if `exclude_current_positions_from_observation=True`, else `observation[2:]`) is no longer contained in the closed interval specified by the argument `healthy_state_range`
@@ -95,12 +95,12 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
3. The angle (`observation[1]` if `exclude_current_positions_from_observation=True`, else `observation[2]`) is no longer contained in the closed interval specified by the argument `healthy_angle_range`
If `terminate_when_unhealthy=True` is passed during construction (which is the default),
the episode terminates when any of the following happens:
the episode ends when any of the following happens:
1. The episode duration reaches a 1000 timesteps
2. The hopper is unhealthy
1. Truncation: The episode duration reaches a 1000 timesteps
2. Termination: The hopper is unhealthy
If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded.
If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.
### Arguments
@@ -223,9 +223,9 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
return is_healthy
@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated
def _get_obs(self):
position = self.data.qpos.flat.copy()
@@ -253,14 +253,14 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
}
self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info
def reset_model(self):
noise_low = -self._reset_noise_scale

View File

@@ -60,11 +60,12 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
quad_impact_cost = min(quad_impact_cost, 10)
reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus
qpos = self.sim.data.qpos
done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0))
terminated = bool((qpos[2] < 1.0) or (qpos[2] > 2.0))
return (
self._get_obs(),
reward,
done,
terminated,
False,
dict(
reward_linvel=lin_vel_cost,
reward_quadctrl=-quad_ctrl_cost,

View File

@@ -99,9 +99,9 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
return is_healthy
@property
def done(self):
done = (not self.is_healthy) if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = (not self.is_healthy) if self._terminate_when_unhealthy else False
return terminated
def _get_obs(self):
position = self.sim.data.qpos.flat.copy()
@@ -148,7 +148,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"reward_linvel": forward_reward,
"reward_quadctrl": -ctrl_cost,
@@ -162,7 +162,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
"forward_reward": forward_reward,
}
return observation, reward, done, info
return observation, reward, terminated, False, info
def reset_model(self):
noise_low = -self._reset_noise_scale

View File

@@ -165,18 +165,17 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
selected to be high, thereby indicating a standing up humanoid. The initial
orientation is designed to make it face forward as well.
### Episode Termination
### Episode End
The humanoid is said to be unhealthy if the z-position of the torso is no longer contained in the
closed interval specified by the argument `healthy_z_range`.
If `terminate_when_unhealthy=True` is passed during construction (which is the default),
the episode terminates when any of the following happens:
the episode ends when any of the following happens:
1. The episode duration reaches a 1000 timesteps
3. The humanoid is unhealthy
If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded.
1. Truncation: The episode duration reaches a 1000 timesteps
2. Termination: The humanoid is unhealthy
If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.
### Arguments
@@ -281,9 +280,9 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
return is_healthy
@property
def done(self):
done = (not self.is_healthy) if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = (not self.is_healthy) if self._terminate_when_unhealthy else False
return terminated
def _get_obs(self):
position = self.data.qpos.flat.copy()
@@ -326,7 +325,7 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
observation = self._get_obs()
reward = rewards - ctrl_cost
done = self.done
terminated = self.terminated
info = {
"reward_linvel": forward_reward,
"reward_quadctrl": -ctrl_cost,
@@ -340,7 +339,7 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
}
self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info
def reset_model(self):
noise_low = -self._reset_noise_scale

View File

@@ -56,11 +56,11 @@ class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):
self.renderer.render_step()
done = bool(False)
return (
self._get_obs(),
reward,
done,
False,
False,
dict(
reward_linup=uph_cost,
reward_quadctrl=-quad_ctrl_cost,

View File

@@ -151,11 +151,11 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
to be low, thereby indicating a laying down humanoid. The initial orientation is
designed to make it face forward as well.
### Episode Termination
The episode terminates when any of the following happens:
### Episode End
The episode ends when any of the following happens:
1. The episode duration reaches a 1000 timesteps
2. Any of the state space values is no longer finite
1. Truncation: The episode duration reaches a 1000 timesteps
2. Termination: Any of the state space values is no longer finite
### Arguments
@@ -228,11 +228,11 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
self.renderer.render_step()
done = bool(False)
return (
self._get_obs(),
reward,
done,
False,
False,
dict(
reward_linup=uph_cost,
reward_quadctrl=-quad_ctrl_cost,

View File

@@ -40,8 +40,8 @@ class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2
alive_bonus = 10
r = alive_bonus - dist_penalty - vel_penalty
done = bool(y <= 1)
return ob, r, done, {}
terminated = bool(y <= 1)
return ob, r, terminated, False, {}
def _get_obs(self):
return np.concatenate(

View File

@@ -85,12 +85,12 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
of [-0.1, 0.1] added to the positional values (cart position and pole angles) and standard
normal force with a standard deviation of 0.1 added to the velocity values for stochasticity.
### Episode Termination
The episode terminates when any of the following happens:
### Episode End
The episode ends when any of the following happens:
1. The episode duration reaches 1000 timesteps.
2. Any of the state space values is no longer finite.
3. The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other).
1. Truncation: The episode duration reaches 1000 timesteps.
2. Termination: Any of the state space values is no longer finite.
3. Termination: The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other.
### Arguments
@@ -143,11 +143,9 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2
alive_bonus = 10
r = alive_bonus - dist_penalty - vel_penalty
done = bool(y <= 1)
terminated = bool(y <= 1)
self.renderer.render_step()
return ob, r, done, {}
return ob, r, terminated, False, {}
def _get_obs(self):
return np.concatenate(

View File

@@ -35,9 +35,8 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
self.renderer.render_step()
ob = self._get_obs()
notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2)
done = not notdone
return ob, reward, done, {}
terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2))
return ob, reward, terminated, False, {}
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(

View File

@@ -56,12 +56,12 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
(0.0, 0.0, 0.0, 0.0) with a uniform noise in the range
of [-0.01, 0.01] added to the values for stochasticity.
### Episode Termination
The episode terminates when any of the following happens:
### Episode End
The episode ends when any of the following happens:
1. The episode duration reaches 1000 timesteps.
2. Any of the state space values is no longer finite.
3. The absolute value of the vertical angle between the pole and the cart is greater than 0.2 radians.
1. Truncation: The episode duration reaches 1000 timesteps.
2. Termination: Any of the state space values is no longer finite.
3. Termination: The absolute value of the vertical angle between the pole and the cart is greater than 0.2 radians.
### Arguments
@@ -109,12 +109,9 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
reward = 1.0
self.do_simulation(a, self.frame_skip)
ob = self._get_obs()
notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2)
done = not notdone
terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2))
self.renderer.render_step()
return ob, reward, done, {}
return ob, reward, terminated, False, {}
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(

View File

@@ -38,8 +38,13 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle):
self.renderer.render_step()
ob = self._get_obs()
done = False
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
)
def viewer_setup(self):
assert self.viewer is not None

View File

@@ -99,12 +99,12 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
The default framerate is 5 with each frame lasting for 0.01, giving rise to a *dt = 5 * 0.01 = 0.05*
### Episode Termination
### Episode End
The episode terminates when any of the following happens:
The episode ends when any of the following happens:
1. The episode duration reaches a 100 timesteps.
2. Any of the state space values is no longer finite.
1. Truncation: The episode duration reaches a 100 timesteps.
2. Termination: Any of the state space values is no longer finite.
### Arguments
@@ -157,11 +157,14 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip)
ob = self._get_obs()
done = False
self.renderer.render_step()
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
)
def viewer_setup(self):
assert self.viewer is not None

View File

@@ -34,8 +34,13 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
self.renderer.render_step()
ob = self._get_obs()
done = False
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
)
def viewer_setup(self):
assert self.viewer is not None

View File

@@ -89,12 +89,12 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
element ("fingertip" - "target") is calculated at the end once everything
is set. The default setting has a framerate of 2 and a *dt = 2 * 0.01 = 0.02*
### Episode Termination
### Episode End
The episode terminates when any of the following happens:
The episode ends when any of the following happens:
1. The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps)
2. Any of the state space values is no longer finite.
1. Truncation: The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps)
2. Termination: Any of the state space values is no longer finite.
### Arguments
@@ -143,11 +143,14 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
reward = reward_dist + reward_ctrl
self.do_simulation(a, self.frame_skip)
ob = self._get_obs()
done = False
self.renderer.render_step()
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
)
def viewer_setup(self):
assert self.viewer is not None

View File

@@ -36,7 +36,13 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
reward = reward_fwd + reward_ctrl
ob = self._get_obs()
return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl),
)
def _get_obs(self):
qpos = self.sim.data.qpos

View File

@@ -73,7 +73,6 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
observation = self._get_obs()
reward = forward_reward - ctrl_cost
done = False
info = {
"reward_fwd": forward_reward,
"reward_ctrl": -ctrl_cost,
@@ -85,7 +84,7 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
"forward_reward": forward_reward,
}
return observation, reward, done, info
return observation, reward, False, False, info
def _get_obs(self):
position = self.sim.data.qpos.flat.copy()

View File

@@ -89,8 +89,8 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
### Starting State
All observations start in state (0,0,0,0,0,0,0,0) with a Uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] is added to the initial state for stochasticity.
### Episode Termination
The episode terminates when the episode length is greater than 1000.
### Episode End
The episode truncates when the episode length is greater than 1000.
### Arguments
@@ -183,7 +183,6 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
observation = self._get_obs()
reward = forward_reward - ctrl_cost
done = False
info = {
"reward_fwd": forward_reward,
"reward_ctrl": -ctrl_cost,
@@ -196,7 +195,7 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
}
self.renderer.render_step()
return observation, reward, done, info
return observation, reward, False, False, info
def _get_obs(self):
position = self.data.qpos.flat.copy()

View File

@@ -35,10 +35,9 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
reward = (posafter - posbefore) / self.dt
reward += alive_bonus
reward -= 1e-3 * np.square(a).sum()
done = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0)
terminated = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0)
ob = self._get_obs()
return ob, reward, done, {}
return ob, reward, terminated, False, {}
def _get_obs(self):
qpos = self.sim.data.qpos

View File

@@ -92,9 +92,9 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
return is_healthy
@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated
def _get_obs(self):
position = self.sim.data.qpos.flat.copy()
@@ -123,13 +123,13 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
}
return observation, reward, done, info
return observation, reward, terminated, False, info
def reset_model(self):
noise_low = -self._reset_noise_scale

View File

@@ -92,7 +92,7 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
(0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
with a uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] added to the values for stochasticity.
### Episode Termination
### Episode End
The walker is said to be unhealthy if any of the following happens:
1. Any of the state space values is no longer finite
@@ -100,12 +100,12 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
3. The absolute value of the angle (`observation[1]` if `exclude_current_positions_from_observation=False`, else `observation[2]`) is ***not*** in the closed interval specified by `healthy_angle_range`
If `terminate_when_unhealthy=True` is passed during construction (which is the default),
the episode terminates when any of the following happens:
the episode ends when any of the following happens:
1. The episode duration reaches a 1000 timesteps
2. The walker is unhealthy
1. Truncation: The episode duration reaches a 1000 timesteps
2. Termination: The walker is unhealthy
If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded.
If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.
### Arguments
@@ -221,9 +221,9 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
return is_healthy
@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated
def _get_obs(self):
position = self.data.qpos.flat.copy()
@@ -251,14 +251,14 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
}
self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info
def reset_model(self):
noise_low = -self._reset_noise_scale

View File

@@ -23,7 +23,13 @@ from typing import (
import numpy as np
from gym.envs.__relocated__ import internal_env_relocation_map
from gym.wrappers import AutoResetWrapper, HumanRendering, OrderEnforcing, TimeLimit
from gym.wrappers import (
AutoResetWrapper,
HumanRendering,
OrderEnforcing,
StepAPICompatibility,
TimeLimit,
)
from gym.wrappers.env_checker import PassiveEnvChecker
if sys.version_info < (3, 10):
@@ -118,6 +124,7 @@ class EnvSpec:
max_episode_steps: Optional[int] = field(default=None)
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
new_step_api: bool = field(default=False)
kwargs: dict = field(default_factory=dict)
namespace: Optional[str] = field(init=False)
@@ -522,6 +529,7 @@ def make(
id: Union[str, EnvSpec],
max_episode_steps: Optional[int] = None,
autoreset: bool = False,
new_step_api: bool = False,
disable_env_checker: bool = False,
**kwargs,
) -> Env:
@@ -531,6 +539,7 @@ def make(
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0
disable_env_checker: If to disable the environment checker
kwargs: Additional arguments to pass to the environment constructor.
@@ -644,19 +653,21 @@ def make(
if disable_env_checker is False:
env = PassiveEnvChecker(env)
env = StepAPICompatibility(env, new_step_api)
# Add the order enforcing wrapper
if spec_.order_enforce:
env = OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
env = TimeLimit(env, max_episode_steps, new_step_api)
elif spec_.max_episode_steps is not None:
env = TimeLimit(env, spec_.max_episode_steps)
env = TimeLimit(env, spec_.max_episode_steps, new_step_api)
# Add the autoreset wrapper
if autoreset:
env = AutoResetWrapper(env)
env = AutoResetWrapper(env, new_step_api)
return env
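A short sketch of the two call styles `make` now supports (assuming the registered "CartPole-v1" id):

import gym

env_old = gym.make("CartPole-v1")  # default: the wrapper stack converts back to the old step API
env_old.reset()
obs, reward, done, info = env_old.step(env_old.action_space.sample())

env_new = gym.make("CartPole-v1", new_step_api=True)  # StepAPICompatibility and TimeLimit receive the flag
env_new.reset()
obs, reward, terminated, truncated, info = env_new.step(env_new.action_space.sample())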

View File

@@ -137,13 +137,13 @@ class BlackjackEnv(gym.Env):
if action: # hit: add a card to players hand and return
self.player.append(draw_card(self.np_random))
if is_bust(self.player):
done = True
terminated = True
reward = -1.0
else:
done = False
terminated = False
reward = 0.0
else: # stick: play out the dealers hand, and score
done = True
terminated = True
while sum_hand(self.dealer) < 17:
self.dealer.append(draw_card(self.np_random))
reward = cmp(score(self.player), score(self.dealer))
@@ -158,9 +158,8 @@ class BlackjackEnv(gym.Env):
):
# Natural gives extra points, but doesn't autowin. Legacy implementation
reward = 1.5
self.renderer.render_step()
return self._get_obs(), reward, done, {}
return self._get_obs(), reward, terminated, False, {}
def _get_obs(self):
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))

View File

@@ -111,7 +111,7 @@ class CliffWalkingEnv(Env):
delta: Change in position for transition
Returns:
Tuple of ``(1.0, new_state, reward, done)``
Tuple of ``(1.0, new_state, reward, terminated)``
"""
new_position = np.array(current) + np.array(delta)
new_position = self._limit_coordinates(new_position).astype(int)
@@ -120,17 +120,17 @@ class CliffWalkingEnv(Env):
return [(1.0, self.start_state_index, -100, False)]
terminal_state = (self.shape[0] - 1, self.shape[1] - 1)
is_done = tuple(new_position) == terminal_state
return [(1.0, new_state, -1, is_done)]
is_terminated = tuple(new_position) == terminal_state
return [(1.0, new_state, -1, is_terminated)]
def step(self, a):
transitions = self.P[self.s][a]
i = categorical_sample([t[0] for t in transitions], self.np_random)
p, s, r, d = transitions[i]
p, s, r, t = transitions[i]
self.s = s
self.lastaction = a
self.renderer.render_step()
return (int(s), r, d, {"prob": p})
return (int(s), r, t, False, {"prob": p})
def reset(
self,

View File

@@ -201,9 +201,9 @@ class FrozenLakeEnv(Env):
newrow, newcol = inc(row, col, action)
newstate = to_s(newrow, newcol)
newletter = desc[newrow, newcol]
done = bytes(newletter) in b"GH"
terminated = bytes(newletter) in b"GH"
reward = float(newletter == b"G")
return newstate, reward, done
return newstate, reward, terminated
for row in range(nrow):
for col in range(ncol):
@@ -242,13 +242,11 @@ class FrozenLakeEnv(Env):
def step(self, a):
transitions = self.P[self.s][a]
i = categorical_sample([t[0] for t in transitions], self.np_random)
p, s, r, d = transitions[i]
p, s, r, t = transitions[i]
self.s = s
self.lastaction = a
self.renderer.render_step()
return (int(s), r, d, {"prob": p})
return (int(s), r, t, False, {"prob": p})
def reset(
self,

View File

@@ -156,7 +156,7 @@ class TaxiEnv(Env):
reward = (
-1
) # default reward when there is no pickup/dropoff
done = False
terminated = False
taxi_loc = (row, col)
if action == 0:
@@ -175,7 +175,7 @@ class TaxiEnv(Env):
elif action == 5: # dropoff
if (taxi_loc == locs[dest_idx]) and pass_idx == 4:
new_pass_idx = dest_idx
done = True
terminated = True
reward = 20
elif (taxi_loc in locs) and pass_idx == 4:
new_pass_idx = locs.index(taxi_loc)
@@ -184,7 +184,9 @@ class TaxiEnv(Env):
new_state = self.encode(
new_row, new_col, new_pass_idx, dest_idx
)
self.P[state][action].append((1.0, new_state, reward, done))
self.P[state][action].append(
(1.0, new_state, reward, terminated)
)
self.initial_state_distrib /= self.initial_state_distrib.sum()
self.action_space = spaces.Discrete(num_actions)
self.observation_space = spaces.Discrete(num_states)
@@ -254,12 +256,11 @@ class TaxiEnv(Env):
def step(self, a):
transitions = self.P[self.s][a]
i = categorical_sample([t[0] for t in transitions], self.np_random)
p, s, r, d = transitions[i]
p, s, r, t = transitions[i]
self.s = s
self.lastaction = a
self.renderer.render_step()
return int(s), r, d, {"prob": p, "action_mask": self.action_mask(s)}
return (int(s), r, t, False, {"prob": p, "action_mask": self.action_mask(s)})
def reset(
self,

View File

@@ -58,7 +58,7 @@ class ResetNeeded(Error):
class ResetNotAllowed(Error):
"""When the monitor is active, raised when the user tries to step an environment that's not yet done."""
"""When the monitor is active, raised when the user tries to step an environment that's not yet terminated or truncated."""
class InvalidAction(Error):

View File

@@ -4,6 +4,7 @@ import inspect
import numpy as np
from gym import error, logger, spaces
from gym.logger import deprecation
def _check_box_observation_space(observation_space: spaces.Box):
@@ -253,14 +254,24 @@ def passive_env_step_check(env, action):
"""A passive check for the environment step, investigating the returning data then returning the data unchanged."""
result = env.step(action)
if len(result) == 4:
deprecation(
"Core environment is written in old step API which returns one bool instead of two. "
"It is recommended to rewrite the environment with new step API. "
)
obs, reward, done, info = result
assert isinstance(done, bool), "The `done` signal must be a boolean"
assert isinstance(
done, bool
), f"The `done` signal is of type `{type(done)}`. It must be a boolean"
elif len(result) == 5:
obs, reward, terminated, truncated, info = result
assert isinstance(terminated, bool), "The `terminated` signal must be a boolean"
assert isinstance(truncated, bool), "The `truncated` signal must be a boolean"
assert isinstance(
terminated, bool
), f"The `terminated` signal is of type `{type(terminated)}`. It must be a boolean"
assert isinstance(
truncated, bool
), f"The `truncated` signal is of type `{type(truncated)}`. It must be a boolean."
assert (
terminated is False or truncated is False
), "Only `terminated` or `truncated` can be true, not both."

View File

@@ -1,4 +1,7 @@
"""Utilities of visualising an environment."""
# TODO: Convert to new step API in 1.0
from collections import deque
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -208,6 +211,10 @@ def play(
seed: Random seed used when resetting the environment. If None, no seed is used.
noop: The action used when no key input has been entered, or the entered key combination is unknown.
"""
deprecation(
"`play.py` currently supports only the old step API which returns one boolean; however, this will soon be updated to support only the new step API that returns two bools."
)
env.reset(seed=seed)
if keys_to_action is None:

View File

@@ -0,0 +1,180 @@
"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0."""
from typing import Tuple, Union
import numpy as np
from gym.core import ObsType
OldStepType = Tuple[
Union[ObsType, np.ndarray],
Union[float, np.ndarray],
Union[bool, np.ndarray],
Union[dict, list],
]
NewStepType = Tuple[
Union[ObsType, np.ndarray],
Union[float, np.ndarray],
Union[bool, np.ndarray],
Union[bool, np.ndarray],
Union[dict, list],
]
def step_to_new_api(
step_returns: Union[OldStepType, NewStepType], is_vector_env=False
) -> NewStepType:
"""Function to transform step returns to new step API irrespective of input API.
Args:
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
is_vector_env (bool): Whether the step_returns are from a vector environment
"""
if len(step_returns) == 5:
return step_returns
else:
assert len(step_returns) == 4
observations, rewards, dones, infos = step_returns
terminateds = []
truncateds = []
if not is_vector_env:
dones = [dones]
for i in range(len(dones)):
# For every condition, handling - info single env / info vector env (list) / info vector env (dict)
# TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done`
if (not is_vector_env and "TimeLimit.truncated" not in infos) or (
is_vector_env
and (
(
isinstance(infos, list)
and "TimeLimit.truncated" not in infos[i]
) # vector env, list info api
or (
"TimeLimit.truncated" not in infos
or (
"TimeLimit.truncated" in infos
and not infos["_TimeLimit.truncated"][i]
)
) # vector env, dict info api, if mask is False, it's the same as TimeLimit.truncated attribute not being present for env 'i'
)
):
terminateds.append(dones[i])
truncateds.append(False)
# This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not.
elif (
infos["TimeLimit.truncated"]
if not is_vector_env
else (
infos["TimeLimit.truncated"][i]
if isinstance(infos, dict)
else infos[i]["TimeLimit.truncated"]
)
):
assert dones[i]
terminateds.append(False)
truncateds.append(True)
else:
# This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated,
# but it also exceeded maximum timesteps at the same step.
assert dones[i]
terminateds.append(True)
truncateds.append(True)
return (
observations,
rewards,
np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0],
np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0],
infos,
)
def step_to_old_api(
step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False
) -> OldStepType:
"""Function to transform step returns to old step API irrespective of input API.
Args:
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
is_vector_env (bool): Whether the step_returns are from a vector environment
"""
if len(step_returns) == 4:
return step_returns
else:
assert len(step_returns) == 5
observations, rewards, terminateds, truncateds, infos = step_returns
dones = []
if not is_vector_env:
terminateds = [terminateds]
truncateds = [truncateds]
n_envs = len(terminateds)
for i in range(n_envs):
dones.append(terminateds[i] or truncateds[i])
if truncateds[i]:
if is_vector_env:
# handle vector infos for dict and list
if isinstance(infos, dict):
if "TimeLimit.truncated" not in infos:
# TODO: This should ideally not be done manually and should use vector_env's _add_info()
infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
infos["TimeLimit.truncated"][i] = (
not terminateds[i] or infos["TimeLimit.truncated"][i]
)
infos["_TimeLimit.truncated"][i] = True
else:
# if vector info is a list
infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[
i
].get("TimeLimit.truncated", False)
else:
infos["TimeLimit.truncated"] = not terminateds[i] or infos.get(
"TimeLimit.truncated", False
)
return (
observations,
rewards,
np.array(dones, dtype=np.bool_) if is_vector_env else dones[0],
infos,
)
def step_api_compatibility(
step_returns: Union[NewStepType, OldStepType],
new_step_api: bool = False,
is_vector_env: bool = False,
) -> Union[NewStepType, OldStepType]:
"""Function to transform step returns to the API specified by `new_step_api` bool.
Old step API refers to step() method returning (observation, reward, done, info)
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
(Refer to docs for details on the API change)
Args:
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
new_step_api (bool): Whether the output should be in new step API or old (False by default)
is_vector_env (bool): Whether the step_returns are from a vector environment
Returns:
step_returns (tuple): Depending on `new_step_api` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
Examples:
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API,
wrapper is written in new API, and the final step output is desired to be in old API.
>>> obs, rew, done, info = step_api_compatibility(env.step(action))
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True)
>>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True)
"""
if new_step_api:
return step_to_new_api(step_returns, is_vector_env)
else:
return step_to_old_api(step_returns, is_vector_env)
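A small worked example of the conversion rules above for a single (non-vector) environment, with a placeholder observation:

from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api

obs = [0.0]  # placeholder observation

# old-API timeout: done=True carrying the TimeLimit.truncated flag
_, _, terminated, truncated, _ = step_to_new_api(
    (obs, 1.0, True, {"TimeLimit.truncated": True})
)
assert terminated is False and truncated is True

# new-API termination converts back to done=True, and no truncation flag is added
_, _, done, info = step_to_old_api((obs, 1.0, True, False, {}))
assert done is True and "TimeLimit.truncated" not in info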

View File

@@ -15,6 +15,7 @@ def make(
asynchronous: bool = True,
wrappers: Optional[Union[callable, List[callable]]] = None,
disable_env_checker: bool = False,
new_step_api: bool = False,
**kwargs,
) -> VectorEnv:
"""Create a vectorized environment from multiple copies of an environment, from its id.
@@ -35,6 +36,7 @@ def make(
asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
disable_env_checker: If to disable the env checker, if True it will only run on the first environment created.
new_step_api: If True, the vector environment's step method outputs two booleans `terminated`, `truncated` instead of one `done`.
**kwargs: Keyword arguments applied during gym.make
Returns:
@@ -46,7 +48,10 @@ def make(
def _make_env():
env = gym.envs.registration.make(
id, disable_env_checker=_disable_env_checker, **kwargs
id,
disable_env_checker=_disable_env_checker,
new_step_api=True,
**kwargs,
)
if wrappers is not None:
if callable(wrappers):
@@ -65,4 +70,8 @@ def make(
env_fns = [
create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs)
]
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
return (
AsyncVectorEnv(env_fns, new_step_api=new_step_api)
if asynchronous
else SyncVectorEnv(env_fns, new_step_api=new_step_api)
)
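A hedged usage sketch of the new keyword, mirroring the vector compatibility tests further below:

import gym

# With new_step_api=True the vector env's step returns two boolean arrays.
envs = gym.vector.make("CartPole-v1", num_envs=4, asynchronous=False, new_step_api=True)
envs.reset()
observations, rewards, terminateds, truncateds, infos = envs.step(
    envs.action_space.sample()
)
envs.close()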

View File

@@ -17,6 +17,7 @@ from gym.error import (
CustomSpaceError,
NoAsyncCallError,
)
from gym.utils.step_api_compatibility import step_api_compatibility
from gym.vector.utils import (
CloudpickleWrapper,
clear_mpi_env_vars,
@@ -66,6 +67,7 @@ class AsyncVectorEnv(VectorEnv):
context: Optional[str] = None,
daemon: bool = True,
worker: Optional[callable] = None,
new_step_api: bool = False,
):
"""Vectorized environment that runs multiple environments in parallel.
@@ -84,7 +86,8 @@ class AsyncVectorEnv(VectorEnv):
the head process quits. However, ``daemon=True`` prevents subprocesses from spawning children,
so for some environments you may want to have it set to ``False``.
worker: If set, then use that worker in a subprocess instead of a default one.
Can be useful to override some inner vector env logic, for instance, how resets on done are handled.
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
@@ -112,6 +115,7 @@ class AsyncVectorEnv(VectorEnv):
num_envs=len(env_fns),
observation_space=observation_space,
action_space=action_space,
new_step_api=new_step_api,
)
if self.shared_memory:
@@ -338,7 +342,7 @@ class AsyncVectorEnv(VectorEnv):
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
Returns:
The batched environment step information, obs, reward, done and info
The batched environment step information, (obs, reward, terminated, truncated, info) or (obs, reward, done, info) depending on new_step_api
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
@@ -358,16 +362,17 @@ class AsyncVectorEnv(VectorEnv):
f"The call to `step_wait` has timed out after {timeout} second(s)."
)
observations_list, rewards, dones, infos = [], [], [], {}
observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {}
successes = []
for i, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv()
obs, rew, done, info = result
obs, rew, terminated, truncated, info = step_api_compatibility(result, True)
successes.append(success)
observations_list.append(obs)
rewards.append(rew)
dones.append(done)
terminateds.append(terminated)
truncateds.append(truncated)
infos = self._add_info(infos, info, i)
self._raise_if_errors(successes)
@@ -380,11 +385,16 @@ class AsyncVectorEnv(VectorEnv):
self.observations,
)
return (
deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
np.array(dones, dtype=np.bool_),
infos,
return step_api_compatibility(
(
deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
np.array(terminateds, dtype=np.bool_),
np.array(truncateds, dtype=np.bool_),
infos,
),
self.new_step_api,
True,
)
def call_async(self, name: str, *args, **kwargs):
@@ -604,11 +614,17 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
pipe.send((observation, True))
elif command == "step":
observation, reward, done, info = env.step(data)
if done:
info["terminal_observation"] = observation
(
observation,
reward,
terminated,
truncated,
info,
) = step_api_compatibility(env.step(data), True)
if terminated or truncated:
info["final_observation"] = observation
observation = env.reset()
pipe.send(((observation, reward, done, info), True))
pipe.send(((observation, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
pipe.send((None, True))
@@ -673,14 +689,20 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
)
pipe.send((None, True))
elif command == "step":
observation, reward, done, info = env.step(data)
if done:
info["terminal_observation"] = observation
(
observation,
reward,
terminated,
truncated,
info,
) = step_api_compatibility(env.step(data), True)
if terminated or truncated:
info["final_observation"] = observation
observation = env.reset()
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send(((None, reward, done, info), True))
pipe.send(((None, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
pipe.send((None, True))
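When a sub-environment finishes, the worker stores its last observation under "final_observation" before resetting it. A sketch of reading that observation back out of the dict-style vector infos; `envs` and `actions` are assumed to come from surrounding code:

import numpy as np

observations, rewards, terminateds, truncateds, infos = envs.step(actions)
for i in np.where(terminateds | truncateds)[0]:
    # "_final_observation" is the boolean mask marking envs that actually finished.
    assert infos["_final_observation"][i]
    last_obs = infos["final_observation"][i]  # observation recorded before the autoreset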

View File

@@ -6,6 +6,7 @@ import numpy as np
from gym import Env
from gym.spaces import Space
from gym.utils.step_api_compatibility import step_api_compatibility
from gym.vector.utils import concatenate, create_empty_array, iterate
from gym.vector.vector_env import VectorEnv
@@ -33,6 +34,7 @@ class SyncVectorEnv(VectorEnv):
observation_space: Space = None,
action_space: Space = None,
copy: bool = True,
new_step_api: bool = False,
):
"""Vectorized environment that serially runs multiple environments.
@@ -60,6 +62,7 @@ class SyncVectorEnv(VectorEnv):
num_envs=len(self.envs),
observation_space=observation_space,
action_space=action_space,
new_step_api=new_step_api,
)
self._check_spaces()
@@ -67,7 +70,8 @@ class SyncVectorEnv(VectorEnv):
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None
def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
@@ -108,7 +112,8 @@ class SyncVectorEnv(VectorEnv):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
self._dones[:] = False
self._terminateds[:] = False
self._truncateds[:] = False
observations = []
infos = {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
@@ -151,9 +156,15 @@ class SyncVectorEnv(VectorEnv):
"""
observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
observation, self._rewards[i], self._dones[i], info = env.step(action)
if self._dones[i]:
info["terminal_observation"] = observation
(
observation,
self._rewards[i],
self._terminateds[i],
self._truncateds[i],
info,
) = step_api_compatibility(env.step(action), True)
if self._terminateds[i] or self._truncateds[i]:
info["final_observation"] = observation
observation = env.reset()
observations.append(observation)
infos = self._add_info(infos, info, i)
@@ -161,11 +172,16 @@ class SyncVectorEnv(VectorEnv):
self.single_observation_space, observations, self.observations
)
return (
deepcopy(self.observations) if self.copy else self.observations,
np.copy(self._rewards),
np.copy(self._dones),
infos,
return step_api_compatibility(
(
deepcopy(self.observations) if self.copy else self.observations,
np.copy(self._rewards),
np.copy(self._terminateds),
np.copy(self._truncateds),
infos,
),
new_step_api=self.new_step_api,
is_vector_env=True,
)
def call(self, name, *args, **kwargs) -> tuple:

View File

@@ -1,5 +1,5 @@
"""Base class for vectorized environments."""
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Tuple, Union
import numpy as np
@@ -24,7 +24,11 @@ class VectorEnv(gym.Env):
"""
def __init__(
self, num_envs: int, observation_space: gym.Space, action_space: gym.Space
self,
num_envs: int,
observation_space: gym.Space,
action_space: gym.Space,
new_step_api: bool = False,
):
"""Base class for vectorized environments.
@@ -32,6 +36,7 @@ class VectorEnv(gym.Env):
num_envs: Number of environments in the vectorized environment.
observation_space: Observation space of a single environment.
action_space: Action space of a single environment.
new_step_api (bool): Whether the vector env's step method outputs two boolean arrays (new API) or one boolean array (old API)
"""
self.num_envs = num_envs
self.is_vector_env = True
@@ -46,6 +51,13 @@ class VectorEnv(gym.Env):
self.single_observation_space = observation_space
self.single_action_space = action_space
self.new_step_api = new_step_api
if not self.new_step_api:
deprecation(
"Initializing vector env in old step API which returns one bool array instead of two. "
"It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. "
)
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
@@ -135,7 +147,7 @@ class VectorEnv(gym.Env):
actions: element of :attr:`action_space` Batch of actions.
Returns:
Batch of observations, rewards, done and infos
Batch of (observations, rewards, terminateds, truncateds, infos) or (observations, rewards, dones, infos)
"""
self.step_async(actions)
return self.step_wait()
@@ -143,7 +155,7 @@ class VectorEnv(gym.Env):
def call_async(self, name, *args, **kwargs):
"""Calls a method name for each parallel environment asynchronously."""
def call_wait(self, **kwargs) -> List[Any]:
def call_wait(self, **kwargs) -> List[Any]: # type: ignore
"""After calling a method in :meth:`call_async`, this function collects the results."""
def call(self, name: str, *args, **kwargs) -> List[Any]:
@@ -251,7 +263,7 @@ class VectorEnv(gym.Env):
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos
def _init_info_arrays(self, dtype: type) -> np.ndarray:
def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
"""Initialize the info array.
Initialize the info array. If the dtype is numeric

View File

@@ -14,6 +14,7 @@ from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
from gym.wrappers.rescale_action import RescaleAction
from gym.wrappers.resize_observation import ResizeObservation
from gym.wrappers.step_api_compatibility import StepAPICompatibility
from gym.wrappers.time_aware_observation import TimeAwareObservation
from gym.wrappers.time_limit import TimeLimit
from gym.wrappers.transform_observation import TransformObservation

View File

@@ -3,6 +3,7 @@ import numpy as np
import gym
from gym.spaces import Box
from gym.utils.step_api_compatibility import step_api_compatibility
try:
import cv2
@@ -37,6 +38,7 @@ class AtariPreprocessing(gym.Wrapper):
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
scale_obs: bool = False,
new_step_api: bool = False,
):
"""Wrapper for Atari 2600 preprocessing.
@@ -45,7 +47,7 @@ class AtariPreprocessing(gym.Wrapper):
noop_max (int): For no-op reset, the maximum number of no-op actions taken at reset; to turn off, set to 0.
frame_skip (int): The number of frames between new observations; this affects the frequency at which the agent experiences the game.
screen_size (int): resize Atari frame
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `done=True` whenever a
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
life is lost.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
is returned.
@@ -58,7 +60,7 @@ class AtariPreprocessing(gym.Wrapper):
DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env
"""
super().__init__(env)
super().__init__(env, new_step_api)
if cv2 is None:
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
@@ -114,20 +116,22 @@ class AtariPreprocessing(gym.Wrapper):
def step(self, action):
"""Applies the preprocessing for an :meth:`env.step`."""
total_reward, done, info = 0.0, False, {}
total_reward, terminated, truncated, info = 0.0, False, False, {}
for t in range(self.frame_skip):
_, reward, done, info = self.env.step(action)
_, reward, terminated, truncated, info = step_api_compatibility(
self.env.step(action), True
)
total_reward += reward
self.game_over = done
self.game_over = terminated
if self.terminal_on_life_loss:
new_lives = self.ale.lives()
done = done or new_lives < self.lives
self.game_over = done
terminated = terminated or new_lives < self.lives
self.game_over = terminated
self.lives = new_lives
if done:
if terminated or truncated:
break
if t == self.frame_skip - 2:
if self.grayscale_obs:
@@ -139,7 +143,10 @@ class AtariPreprocessing(gym.Wrapper):
self.ale.getScreenGrayscale(self.obs_buffer[0])
else:
self.ale.getScreenRGB(self.obs_buffer[0])
return self._get_obs(), total_reward, done, info
return step_api_compatibility(
(self._get_obs(), total_reward, terminated, truncated, info),
self.new_step_api,
)
def reset(self, **kwargs):
"""Resets the environment using preprocessing."""
@@ -156,9 +163,11 @@ class AtariPreprocessing(gym.Wrapper):
else 0
)
for _ in range(noops):
_, _, done, step_info = self.env.step(0)
_, _, terminated, truncated, step_info = step_api_compatibility(
self.env.step(0), True
)
reset_info.update(step_info)
if done:
if terminated or truncated:
if kwargs.get("return_info", False):
_, reset_info = self.env.reset(**kwargs)
else:

View File

@@ -1,29 +1,40 @@
"""Wrapper that autoreset environments when `done=True`."""
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
import gym
from gym.utils.step_api_compatibility import step_api_compatibility
class AutoResetWrapper(gym.Wrapper):
"""A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`.
When calling step causes :meth:`Env.step` to return done, :meth:`Env.reset` is called,
and the return format of :meth:`self.step` is as follows: ``(new_obs, terminal_reward, terminal_done, info)``
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)``
with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
- ``new_obs`` is the first observation after calling :meth:`self.env.reset`
- ``terminal_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
- ``terminal_done`` is always True
- ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
- ``final_done`` is always True. In the new API, either ``final_terminated`` or ``final_truncated`` is True
- ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
with an additional key "terminal_observation" containing the observation returned by the last call to :meth:`self.env.step`
and "terminal_info" containing the info dict returned by the last call to :meth:`self.env.step`.
with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.
Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns done, a
Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
terminal reward and done state from the previous episode.
If you need the terminal state from the previous episode, you need to retrieve it via the
"terminal_observation" key in the info dict.
final reward and done state from the previous episode.
If you need the final state from the previous episode, you need to retrieve it via the
"final_observation" key in the info dict.
Make sure you know what you're doing if you use this wrapper!
"""
def __init__(self, env: gym.Env, new_step_api: bool = False):
"""A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`.
Args:
env (gym.Env): The environment to apply the wrapper
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env, new_step_api)
def step(self, action):
"""Steps through the environment with action and resets the environment if a done-signal is encountered.
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
Args:
action: The action to take
@@ -31,22 +42,26 @@ class AutoResetWrapper(gym.Wrapper):
Returns:
The autoreset environment :meth:`step`
"""
obs, reward, done, info = self.env.step(action)
obs, reward, terminated, truncated, info = step_api_compatibility(
self.env.step(action), True
)
if done:
if terminated or truncated:
new_obs, new_info = self.env.reset(return_info=True)
assert (
"terminal_observation" not in new_info
), 'info dict cannot contain key "terminal_observation" '
"final_observation" not in new_info
), 'info dict cannot contain key "final_observation" '
assert (
"terminal_info" not in new_info
), 'info dict cannot contain key "terminal_info" '
"final_info" not in new_info
), 'info dict cannot contain key "final_info" '
new_info["terminal_observation"] = obs
new_info["terminal_info"] = info
new_info["final_observation"] = obs
new_info["final_info"] = info
obs = new_obs
info = new_info
return obs, reward, done, info
return step_api_compatibility(
(obs, reward, terminated, truncated, info), self.new_step_api
)
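A short usage sketch of the renamed keys; the env choice is illustrative only:

import gym
from gym.wrappers import AutoResetWrapper

env = AutoResetWrapper(gym.make("CartPole-v1", new_step_api=True), new_step_api=True)
env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
if terminated or truncated:
    # `obs` is already the first observation of the next episode; the finished
    # episode's last observation and info are kept in the info dict.
    last_obs = info["final_observation"]
    last_info = info["final_info"]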

View File

@@ -26,7 +26,7 @@ class ClipAction(ActionWrapper):
env: The environment to apply the wrapper
"""
assert isinstance(env.action_space, Box)
super().__init__(env)
super().__init__(env, new_step_api=True)
def action(self, action):
"""Clips the action within the valid bounds.

View File

@@ -15,7 +15,7 @@ class PassiveEnvChecker(gym.Wrapper):
def __init__(self, env):
"""Initialises the wrapper with the environments, run the observation and action space tests."""
super().__init__(env)
super().__init__(env, new_step_api=True)
assert hasattr(
env, "action_space"

View File

@@ -35,7 +35,7 @@ class FilterObservation(gym.ObservationWrapper):
ValueError: If the environment's observation space is not :class:`spaces.Dict`
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
wrapped_observation_space = env.observation_space
if not isinstance(wrapped_observation_space, spaces.Dict):

View File

@@ -25,7 +25,7 @@ class FlattenObservation(gym.ObservationWrapper):
Args:
env: The environment to apply the wrapper
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
self.observation_space = spaces.flatten_space(env.observation_space)
def observation(self, observation):

View File

@@ -7,6 +7,7 @@ import numpy as np
import gym
from gym.error import DependencyNotInstalled
from gym.spaces import Box
from gym.utils.step_api_compatibility import step_api_compatibility
class LazyFrames:
@@ -122,15 +123,22 @@ class FrameStack(gym.ObservationWrapper):
(4, 96, 96, 3)
"""
def __init__(self, env: gym.Env, num_stack: int, lz4_compress: bool = False):
def __init__(
self,
env: gym.Env,
num_stack: int,
lz4_compress: bool = False,
new_step_api: bool = False,
):
"""Observation wrapper that stacks the observations in a rolling manner.
Args:
env (Env): The environment to apply the wrapper
num_stack (int): The number of frames to stack
lz4_compress (bool): Use lz4 to compress the frames internally
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env)
super().__init__(env, new_step_api)
self.num_stack = num_stack
self.lz4_compress = lz4_compress
@@ -163,11 +171,15 @@ class FrameStack(gym.ObservationWrapper):
action: The action to step through the environment with
Returns:
Stacked observations, reward, done and information from the environment
Stacked observations, reward, terminated, truncated, and information from the environment
"""
observation, reward, done, info = self.env.step(action)
observation, reward, terminated, truncated, info = step_api_compatibility(
self.env.step(action), True
)
self.frames.append(observation)
return self.observation(None), reward, done, info
return step_api_compatibility(
(self.observation(None), reward, terminated, truncated, info), self.new_step_api
)
def reset(self, **kwargs):
"""Reset the environment with kwargs.

View File

@@ -28,7 +28,7 @@ class GrayScaleObservation(gym.ObservationWrapper):
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
Otherwise, they are of shape AxB.
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
self.keep_dim = keep_dim
assert (

View File

@@ -45,7 +45,7 @@ class HumanRendering(gym.Wrapper):
Args:
env: The environment that is being wrapped
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
assert env.render_mode in [
"single_rgb_array",
"rgb_array",

View File

@@ -2,6 +2,7 @@
import numpy as np
import gym
from gym.utils.step_api_compatibility import step_api_compatibility
# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
@@ -54,14 +55,15 @@ class NormalizeObservation(gym.core.Wrapper):
newly instantiated or the policy was changed recently.
"""
def __init__(self, env: gym.Env, epsilon: float = 1e-8):
def __init__(self, env: gym.Env, epsilon: float = 1e-8, new_step_api: bool = False):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Args:
env (Env): The environment to apply the wrapper
epsilon: A stability parameter that is used when scaling the observations.
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env)
super().__init__(env, new_step_api)
self.num_envs = getattr(env, "num_envs", 1)
self.is_vector_env = getattr(env, "is_vector_env", False)
if self.is_vector_env:
@@ -72,12 +74,18 @@ class NormalizeObservation(gym.core.Wrapper):
def step(self, action):
"""Steps through the environment and normalizes the observation."""
obs, rews, dones, infos = self.env.step(action)
obs, rews, terminateds, truncateds, infos = step_api_compatibility(
self.env.step(action), True, self.is_vector_env
)
if self.is_vector_env:
obs = self.normalize(obs)
else:
obs = self.normalize(np.array([obs]))[0]
return obs, rews, dones, infos
return step_api_compatibility(
(obs, rews, terminateds, truncateds, infos),
self.new_step_api,
self.is_vector_env,
)
def reset(self, **kwargs):
"""Resets the environment and normalizes the observation."""
@@ -117,6 +125,7 @@ class NormalizeReward(gym.core.Wrapper):
env: gym.Env,
gamma: float = 0.99,
epsilon: float = 1e-8,
new_step_api: bool = False,
):
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
@@ -124,8 +133,9 @@ class NormalizeReward(gym.core.Wrapper):
env (env): The environment to apply the wrapper
epsilon (float): A stability parameter
gamma (float): The discount factor that is used in the exponential moving average.
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env)
super().__init__(env, new_step_api)
self.num_envs = getattr(env, "num_envs", 1)
self.is_vector_env = getattr(env, "is_vector_env", False)
self.return_rms = RunningMeanStd(shape=())
@@ -135,15 +145,25 @@ class NormalizeReward(gym.core.Wrapper):
def step(self, action):
"""Steps through the environment, normalizing the rewards returned."""
obs, rews, dones, infos = self.env.step(action)
obs, rews, terminateds, truncateds, infos = step_api_compatibility(
self.env.step(action), True, self.is_vector_env
)
if not self.is_vector_env:
rews = np.array([rews])
self.returns = self.returns * self.gamma + rews
rews = self.normalize(rews)
if not self.is_vector_env:
dones = terminateds or truncateds
else:
dones = np.bitwise_or(terminateds, truncateds)
self.returns[dones] = 0.0
if not self.is_vector_env:
rews = rews[0]
return obs, rews, dones, infos
return step_api_compatibility(
(obs, rews, terminateds, truncateds, infos),
self.new_step_api,
self.is_vector_env,
)
def normalize(self, rews):
"""Normalizes the rewards with the running mean rewards and their variance."""

View File

@@ -26,7 +26,7 @@ class OrderEnforcing(gym.Wrapper):
env: The environment to wrap
disable_render_order_enforcing: If to disable render order enforcing
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing

View File

@@ -77,7 +77,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
specified ``pixel_keys``.
TypeError: When an unexpected pixel type is used
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
# Avoid side-effects that occur when render_kwargs is manipulated
render_kwargs = copy.deepcopy(render_kwargs)

View File

@@ -6,6 +6,7 @@ from typing import Optional
import numpy as np
import gym
from gym.utils.step_api_compatibility import step_api_compatibility
def add_vector_episode_statistics(
@@ -76,14 +77,15 @@ class RecordEpisodeStatistics(gym.Wrapper):
length_queue: The lengths of the last ``deque_size``-many episodes
"""
def __init__(self, env: gym.Env, deque_size: int = 100):
def __init__(self, env: gym.Env, deque_size: int = 100, new_step_api: bool = False):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env)
super().__init__(env, new_step_api)
self.num_envs = getattr(env, "num_envs", 1)
self.t0 = time.perf_counter()
self.episode_count = 0
@@ -102,18 +104,26 @@ class RecordEpisodeStatistics(gym.Wrapper):
def step(self, action):
"""Steps through the environment, recording the episode statistics."""
observations, rewards, dones, infos = super().step(action)
(
observations,
rewards,
terminateds,
truncateds,
infos,
) = step_api_compatibility(self.env.step(action), True, self.is_vector_env)
assert isinstance(
infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards
self.episode_lengths += 1
if not self.is_vector_env:
dones = [dones]
dones = list(dones)
terminateds = [terminateds]
truncateds = [truncateds]
terminateds = list(terminateds)
truncateds = list(truncateds)
for i in range(len(dones)):
if dones[i]:
for i in range(len(terminateds)):
if terminateds[i] or truncateds[i]:
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {
@@ -134,9 +144,14 @@ class RecordEpisodeStatistics(gym.Wrapper):
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
return (
observations,
rewards,
dones if self.is_vector_env else dones[0],
infos,
return step_api_compatibility(
(
observations,
rewards,
terminateds if self.is_vector_env else terminateds[0],
truncateds if self.is_vector_env else truncateds[0],
infos,
),
self.new_step_api,
self.is_vector_env,
)
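A usage sketch; the layout of the "episode" entry follows gym's existing RecordEpisodeStatistics behaviour and is otherwise an assumption here:

import gym
from gym.wrappers import RecordEpisodeStatistics

env = RecordEpisodeStatistics(gym.make("CartPole-v1", new_step_api=True), new_step_api=True)
env.reset()
terminated = truncated = False
while not (terminated or truncated):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
# When the episode ends, `info` carries the cumulative return, length and elapsed time.
print(info["episode"])  # e.g. {"r": ..., "l": ..., "t": ...}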

View File

@@ -4,6 +4,7 @@ from typing import Callable, Optional
import gym
from gym import logger
from gym.utils.step_api_compatibility import step_api_compatibility
from gym.wrappers.monitoring import video_recorder
@@ -32,7 +33,7 @@ class RecordVideo(gym.Wrapper):
They should be functions returning a boolean that indicates whether a recording should be started at the
current episode or step, respectively.
If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed.
By default, the recording will be stopped once a `done` signal has been emitted by the environment. However, you can
By default, the recording will be stopped once a `terminated` or `truncated` signal has been emitted by the environment. However, you can
also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for
``video_length``.
"""
@@ -45,6 +46,7 @@ class RecordVideo(gym.Wrapper):
step_trigger: Callable[[int], bool] = None,
video_length: int = 0,
name_prefix: str = "rl-video",
new_step_api: bool = False,
):
"""Wrapper records videos of rollouts.
@@ -56,8 +58,9 @@ class RecordVideo(gym.Wrapper):
video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env)
super().__init__(env, new_step_api)
if episode_trigger is None and step_trigger is None:
episode_trigger = capped_cubic_video_schedule
@@ -83,7 +86,8 @@ class RecordVideo(gym.Wrapper):
self.video_length = video_length
self.recording = False
self.done = False
self.terminated = False
self.truncated = False
self.recorded_frames = 0
self.is_vector_env = getattr(env, "is_vector_env", False)
self.episode_id = 0
@@ -91,7 +95,8 @@ class RecordVideo(gym.Wrapper):
def reset(self, **kwargs):
"""Reset the environment using kwargs and then starts recording if video enabled."""
observations = super().reset(**kwargs)
self.done = False
self.terminated = False
self.truncated = False
if self.recording:
assert self.video_recorder is not None
self.video_recorder.frames = []
@@ -132,18 +137,26 @@ class RecordVideo(gym.Wrapper):
def step(self, action):
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""
observations, rewards, dones, infos = super().step(action)
(
observations,
rewards,
terminateds,
truncateds,
infos,
) = step_api_compatibility(self.env.step(action), True, self.is_vector_env)
if not self.done:
if not (self.terminated or self.truncated):
# increment steps and episodes
self.step_id += 1
if not self.is_vector_env:
if dones:
if terminateds or truncateds:
self.episode_id += 1
self.done = True
elif dones[0]:
self.terminated = terminateds
self.truncated = truncateds
elif terminateds[0] or truncateds[0]:
self.episode_id += 1
self.done = True
self.terminated = terminateds[0]
self.truncated = truncateds[0]
if self.recording:
assert self.video_recorder is not None
@@ -154,15 +167,19 @@ class RecordVideo(gym.Wrapper):
self.close_video_recorder()
else:
if not self.is_vector_env:
if dones:
if terminateds or truncateds:
self.close_video_recorder()
elif dones[0]:
elif terminateds[0] or truncateds[0]:
self.close_video_recorder()
elif self._video_enabled():
self.start_video_recorder()
return observations, rewards, dones, infos
return step_api_compatibility(
(observations, rewards, terminateds, truncateds, infos),
self.new_step_api,
self.is_vector_env,
)
def close_video_recorder(self):
"""Closes the video recorder if currently recording."""

View File

@@ -45,7 +45,7 @@ class RescaleAction(gym.ActionWrapper):
), f"expected Box action space, got {type(env.action_space)}"
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
super().__init__(env)
super().__init__(env, new_step_api=True)
self.min_action = (
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
)

View File

@@ -32,7 +32,7 @@ class ResizeObservation(gym.ObservationWrapper):
env: The environment to apply the wrapper
shape: The shape of the resized observations
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
if isinstance(shape, int):
shape = (shape, shape)
assert all(x > 0 for x in shape), shape

View File

@@ -0,0 +1,57 @@
"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API."""
import gym
from gym.logger import deprecation
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
class StepAPICompatibility(gym.Wrapper):
r"""A wrapper which can transform an environment from new step API to old and vice-versa.
Old step API refers to step() method returning (observation, reward, done, info)
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
(Refer to docs for details on the API change)
This wrapper is to be used to ease transition to new API and for backward compatibility.
Args:
env (gym.Env): the env to wrap. Can be in old or new API
new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default)
Examples:
>>> env = gym.make("CartPole-v1")
>>> env # wrapper applied by default, set to old API
<TimeLimit<OrderEnforcing<StepAPICompatibility<CartPoleEnv<CartPole-v1>>>>>
>>> env = gym.make("CartPole-v1", new_step_api=True) # set to new API
>>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs
"""
def __init__(self, env: gym.Env, new_step_api=False):
"""A wrapper which can transform an environment from new step API to old and vice-versa.
Args:
env (gym.Env): the env to wrap. Can be in old or new API
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env, new_step_api)
self.new_step_api = new_step_api
if not self.new_step_api:
deprecation(
"Initializing environment in old step API which returns one bool instead of two. "
"It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. "
)
def step(self, action):
"""Steps through the environment, returning 5 or 4 items depending on `new_step_api`.
Args:
action: action to step through the environment with
Returns:
(observation, reward, terminated, truncated, info) or (observation, reward, done, info)
"""
step_returns = self.env.step(action)
if self.new_step_api:
return step_to_new_api(step_returns)
else:
return step_to_old_api(step_returns)
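A usage sketch of the backward-compatible direction, wrapping a new-API env for old-API consumers:

import gym
from gym.wrappers import StepAPICompatibility

new_api_env = gym.make("CartPole-v1", new_step_api=True)
# Termination and truncation are folded back into a single `done` for legacy code.
old_api_env = StepAPICompatibility(new_api_env, new_step_api=False)
old_api_env.reset()
obs, reward, done, info = old_api_env.step(old_api_env.action_space.sample())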

View File

@@ -3,6 +3,7 @@ import numpy as np
import gym
from gym.spaces import Box
from gym.utils.step_api_compatibility import step_api_compatibility
class TimeAwareObservation(gym.ObservationWrapper):
@@ -21,18 +22,20 @@ class TimeAwareObservation(gym.ObservationWrapper):
array([ 0.03881167, -0.16021058, 0.0220928 , 0.28875574, 1. ])
"""
def __init__(self, env: gym.Env):
def __init__(self, env: gym.Env, new_step_api: bool = False):
"""Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space.
Args:
env: The environment to apply the wrapper
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env)
super().__init__(env, new_step_api)
assert isinstance(env.observation_space, Box)
assert env.observation_space.dtype == np.float32
low = np.append(self.observation_space.low, 0.0)
high = np.append(self.observation_space.high, np.inf)
self.observation_space = Box(low, high, dtype=np.float32)
self.is_vector_env = getattr(env, "is_vector_env", False)
def observation(self, observation):
"""Adds to the observation with the current time step.
@@ -55,7 +58,9 @@ class TimeAwareObservation(gym.ObservationWrapper):
The environment's step using the action.
"""
self.t += 1
return super().step(action)
return step_api_compatibility(
super().step(action), self.new_step_api, self.is_vector_env
)
def reset(self, **kwargs):
"""Reset the environment setting the time to zero.

View File

@@ -2,16 +2,20 @@
from typing import Optional
import gym
from gym.utils.step_api_compatibility import step_api_compatibility
class TimeLimit(gym.Wrapper):
"""This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded.
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
Oftentimes, it is **very** important to distinguish `done` signals that were produced by the
:class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations).
This can be done by looking at the ``info`` that is returned when `done`-signal was issued.
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
(deprecated)
This information is passed through ``info`` that is returned when `done`-signal was issued.
The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``.
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``. This will be removed in favour
of only issuing a `truncated` signal in future versions.
Example:
>>> from gym.envs.classic_control import CartPoleEnv
@@ -20,14 +24,20 @@ class TimeLimit(gym.Wrapper):
>>> env = TimeLimit(env, max_episode_steps=1000)
"""
def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
def __init__(
self,
env: gym.Env,
max_episode_steps: Optional[int] = None,
new_step_api: bool = False,
):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
Args:
env: The environment to apply the wrapper
max_episode_steps: The maximum number of episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env)
super().__init__(env, new_step_api)
if max_episode_steps is None and self.env.spec is not None:
max_episode_steps = env.spec.max_episode_steps
if self.env.spec is not None:
@@ -46,15 +56,19 @@ class TimeLimit(gym.Wrapper):
when truncated (the number of steps elapsed >= max episode steps) or
"TimeLimit.truncated"=False if the environment terminated
"""
observation, reward, done, info = self.env.step(action)
observation, reward, terminated, truncated, info = step_api_compatibility(
self.env.step(action),
True,
)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
# TimeLimit.truncated key may have been already set by the environment
# do not overwrite it
episode_truncated = not done or info.get("TimeLimit.truncated", False)
info["TimeLimit.truncated"] = episode_truncated
done = True
return observation, reward, done, info
truncated = True
return step_api_compatibility(
(observation, reward, terminated, truncated, info),
self.new_step_api,
)
def reset(self, **kwargs):
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.

View File

@@ -27,7 +27,7 @@ class TransformObservation(gym.ObservationWrapper):
env: The environment to apply the wrapper
f: A function that transforms the observation
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
assert callable(f)
self.f = f

View File

@@ -16,7 +16,7 @@ class TransformReward(RewardWrapper):
>>> env = gym.make('CartPole-v1')
>>> env = TransformReward(env, lambda r: 0.01*r)
>>> env.reset()
>>> observation, reward, done, info = env.step(env.action_space.sample())
>>> observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
>>> reward
0.01
"""
@@ -28,7 +28,7 @@ class TransformReward(RewardWrapper):
env: The environment to apply the wrapper
f: A function that transforms the reward
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
assert callable(f)
self.f = f

View File

@@ -3,6 +3,7 @@
from typing import List
import gym
from gym.utils.step_api_compatibility import step_api_compatibility
class VectorListInfo(gym.Wrapper):
@@ -29,23 +30,30 @@ class VectorListInfo(gym.Wrapper):
"""
def __init__(self, env):
def __init__(self, env, new_step_api=False):
"""This wrapper will convert the info into the list format.
Args:
env (Env): The environment to apply the wrapper
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
assert getattr(
env, "is_vector_env", False
), "This wrapper can only be used in vectorized environments."
super().__init__(env)
super().__init__(env, new_step_api)
def step(self, action):
"""Steps through the environment, convert dict info to list."""
observation, reward, done, infos = self.env.step(action)
observation, reward, terminated, truncated, infos = step_api_compatibility(
self.env.step(action), True, True
)
list_info = self._convert_info_to_list(infos)
return observation, reward, done, list_info
return step_api_compatibility(
(observation, reward, terminated, truncated, list_info),
self.new_step_api,
True,
)
def reset(self, **kwargs):
"""Resets the environment using kwargs."""

View File

@@ -34,5 +34,8 @@ reportPrivateUsage = "warning"
reportUntypedFunctionDecorator = "none"
reportMissingTypeStubs = false
reportUnboundVariable = "warning"
reportGeneralTypeIssues = "none"
reportInvalidTypeVarUse = "none"
reportGeneralTypeIssues ="none"
reportInvalidTypeVarUse = "none"
[tool.pytest.ini_options]
filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # TODO: to be removed when old step API is removed

View File

@@ -112,7 +112,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
zip(env.action_space.bounded_above, env.action_space.bounded_below)
):
if is_upper_bound:
obs, _, _, _ = env.step(upper_bounds)
obs, _, _, _, _ = env.step(
upper_bounds
) # `env` is unwrapped, and in new step API
oob_action = upper_bounds.copy()
oob_action[i] += np.cast[dtype](OOB_VALUE)
@@ -122,7 +124,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
assert np.alltrue(obs == oob_obs)
if is_lower_bound:
obs, _, _, _ = env.step(lower_bounds)
obs, _, _, _, _ = env.step(
lower_bounds
) # `env` is unwrapped, and in new step API
oob_action = lower_bounds.copy()
oob_action[i] -= np.cast[dtype](OOB_VALUE)

View File

@@ -144,7 +144,7 @@ def test_taxi_encode_decode():
assert (
env.encode(*env.decode(state)) == state
), f"state={state}, encode(decode(state))={env.encode(*env.decode(state))}"
state, _, _, _ = env.step(env.action_space.sample())
state, _, _, _, _ = env.step(env.action_space.sample())
@pytest.mark.parametrize(

View File

@@ -172,7 +172,9 @@ def test_make_render_mode():
assert env.render_mode == valid_render_modes[0]
env.close()
assert len(warnings) == 0
for warning in warnings.list:
if not re.compile(".*step API.*").match(warning.message.args[0]):
raise gym.error.Error(f"Unexpected warning: {warning.message}")
# Make sure that native rendering is used when possible
env = gym.make("CartPole-v1", render_mode="human", disable_env_checker=True)

View File

@@ -0,0 +1,91 @@
import pytest
import gym
from gym.spaces import Discrete
from gym.vector import AsyncVectorEnv, SyncVectorEnv
from gym.wrappers import TimeLimit
# An environment where termination happens after 20 steps
class DummyEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)
self.terminal_timestep = 20
self.timestep = 0
def step(self, action):
self.timestep += 1
terminated = self.timestep >= self.terminal_timestep
truncated = False
return 0, 0, terminated, truncated, {}
def reset(self):
self.timestep = 0
return 0
@pytest.mark.parametrize("time_limit", [10, 20, 30])
def test_terminated_truncated(time_limit):
test_env = TimeLimit(DummyEnv(), time_limit, new_step_api=True)
terminated = False
truncated = False
test_env.reset()
while not (terminated or truncated):
_, _, terminated, truncated, _ = test_env.step(0)
if test_env.terminal_timestep < time_limit:
assert terminated
assert not truncated
elif test_env.terminal_timestep == time_limit:
assert (
terminated
), "`terminated` should be True even when termination and truncation happen at the same step"
assert (
truncated
), "`truncated` should be True even when termination and truncation occur at same step "
else:
assert not terminated
assert truncated
def test_terminated_truncated_vector():
env0 = TimeLimit(DummyEnv(), 10, new_step_api=True)
env1 = TimeLimit(DummyEnv(), 20, new_step_api=True)
env2 = TimeLimit(DummyEnv(), 30, new_step_api=True)
async_env = AsyncVectorEnv(
[lambda: env0, lambda: env1, lambda: env2], new_step_api=True
)
async_env.reset()
terminateds = [False, False, False]
truncateds = [False, False, False]
counter = 0
while not all([x or y for x, y in zip(terminateds, truncateds)]):
counter += 1
_, _, terminateds, truncateds, _ = async_env.step(
async_env.action_space.sample()
)
print(counter)
assert counter == 20
assert all(terminateds == [False, True, True])
assert all(truncateds == [True, True, False])
sync_env = SyncVectorEnv(
[lambda: env0, lambda: env1, lambda: env2], new_step_api=True
)
sync_env.reset()
terminateds = [False, False, False]
truncateds = [False, False, False]
counter = 0
while not all([x or y for x, y in zip(terminateds, truncateds)]):
counter += 1
_, _, terminateds, truncateds, _ = sync_env.step(
sync_env.action_space.sample()
)
assert counter == 20
assert all(terminateds == [False, True, True])
assert all(truncateds == [True, True, False])

View File

@@ -0,0 +1,88 @@
import numpy as np
import pytest
import gym
from gym.spaces import Discrete
from gym.vector import AsyncVectorEnv, SyncVectorEnv
class OldStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)
def reset(self):
return 0
def step(self, action):
obs = self.observation_space.sample()
rew = 0
done = False
info = {}
return obs, rew, done, info
class NewStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)
def reset(self):
return 0
def step(self, action):
obs = self.observation_space.sample()
rew = 0
terminated = False
truncated = False
info = {}
return obs, rew, terminated, truncated, info
@pytest.mark.parametrize("VecEnv", [AsyncVectorEnv, SyncVectorEnv])
def test_vector_step_compatibility_new_env(VecEnv):
envs = [
OldStepEnv(),
NewStepEnv(),
]
vec_env = VecEnv([lambda env=env: env for env in envs])
vec_env.reset()
step_returns = vec_env.step([0, 0])
assert len(step_returns) == 4
_, _, dones, _ = step_returns
assert dones.dtype == np.bool_
vec_env.close()
vec_env = VecEnv([lambda env=env: env for env in envs], new_step_api=True)
vec_env.reset()
step_returns = vec_env.step([0, 0])
assert len(step_returns) == 5
_, _, terminateds, truncateds, _ = step_returns
assert terminateds.dtype == np.bool_
assert truncateds.dtype == np.bool_
vec_env.close()
@pytest.mark.parametrize("async_bool", [True, False])
def test_vector_step_compatibility_existing(async_bool):
env = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=async_bool)
env.reset()
step_returns = env.step(env.action_space.sample())
assert len(step_returns) == 4
_, _, dones, _ = step_returns
assert dones.dtype == np.bool_
env.close()
env = gym.vector.make(
"CartPole-v1", num_envs=3, asynchronous=async_bool, new_step_api=True
)
env.reset()
step_returns = env.step(env.action_space.sample())
assert len(step_returns) == 5
_, _, terminateds, truncateds, _ = step_returns
assert terminateds.dtype == np.bool_
assert truncateds.dtype == np.bool_
env.close()

View File

@@ -36,10 +36,10 @@ def test_vector_env_equal(shared_memory):
# fmt: on
if any(sync_dones):
assert "terminal_observation" in async_infos
assert "_terminal_observation" in async_infos
assert "terminal_observation" in sync_infos
assert "_terminal_observation" in sync_infos
assert "final_observation" in async_infos
assert "_final_observation" in async_infos
assert "final_observation" in sync_infos
assert "_final_observation" in sync_infos
assert np.all(async_observations == sync_observations)
assert np.all(async_rewards == sync_rewards)

View File

@@ -22,18 +22,18 @@ def test_vector_env_info(asynchronous):
action = env.action_space.sample()
_, _, dones, infos = env.step(action)
if any(dones):
assert len(infos["terminal_observation"]) == NUM_ENVS
assert len(infos["_terminal_observation"]) == NUM_ENVS
assert len(infos["final_observation"]) == NUM_ENVS
assert len(infos["_final_observation"]) == NUM_ENVS
assert isinstance(infos["terminal_observation"], np.ndarray)
assert isinstance(infos["_terminal_observation"], np.ndarray)
assert isinstance(infos["final_observation"], np.ndarray)
assert isinstance(infos["_final_observation"], np.ndarray)
for i, done in enumerate(dones):
if done:
assert infos["_terminal_observation"][i]
assert infos["_final_observation"][i]
else:
assert not infos["_terminal_observation"][i]
assert infos["terminal_observation"][i] is None
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None
@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
@@ -49,8 +49,8 @@ def test_vector_env_info_concurrent_termination(concurrent_ends):
for i, done in enumerate(dones):
if i < concurrent_ends:
assert done
assert infos["_terminal_observation"][i]
assert infos["_final_observation"][i]
else:
assert not infos["_terminal_observation"][i]
assert infos["terminal_observation"][i] is None
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None
return

View File

@@ -136,8 +136,8 @@ def test_autoreset_wrapper_autoreset():
assert reward == 1
assert info == {
"count": 0,
"terminal_observation": np.array([3]),
"terminal_info": {"count": 3},
"final_observation": np.array([3]),
"final_info": {"count": 3},
}
obs, reward, done, info = env.step(action)

View File

@@ -0,0 +1,77 @@
import pytest
import gym
from gym.spaces import Discrete
from gym.wrappers import StepAPICompatibility
class OldStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)
def step(self, action):
obs = self.observation_space.sample()
rew = 0
done = False
info = {}
return obs, rew, done, info
class NewStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)
def step(self, action):
obs = self.observation_space.sample()
rew = 0
terminated = False
truncated = False
info = {}
return obs, rew, terminated, truncated, info
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
def test_step_compatibility_to_new_api(env):
env = StepAPICompatibility(env(), True)
step_returns = env.step(0)
_, _, terminated, truncated, _ = step_returns
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
@pytest.mark.parametrize("new_step_api", [None, False])
def test_step_compatibility_to_old_api(env, new_step_api):
if new_step_api is None:
env = StepAPICompatibility(env()) # default behavior is to retain old API
else:
env = StepAPICompatibility(env(), new_step_api)
step_returns = env.step(0)
assert len(step_returns) == 4
_, _, done, _ = step_returns
assert isinstance(done, bool)
@pytest.mark.parametrize("new_step_api", [None, True, False])
def test_step_compatibility_in_make(new_step_api):
if new_step_api is None:
with pytest.warns(
DeprecationWarning, match="Initializing environment in old step API"
):
env = gym.make("CartPole-v1")
else:
env = gym.make("CartPole-v1", new_step_api=new_step_api)
env.reset()
step_returns = env.step(0)
if new_step_api:
assert len(step_returns) == 5
_, _, terminated, truncated, _ = step_returns
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)
else:
assert len(step_returns) == 4
_, _, done, _ = step_returns
assert isinstance(done, bool)

View File

@@ -32,9 +32,9 @@ def test_info_to_list():
_, _, dones, list_info = wrapped_env.step(action)
for i, done in enumerate(dones):
if done:
assert "terminal_observation" in list_info[i]
assert "final_observation" in list_info[i]
else:
assert "terminal_observation" not in list_info[i]
assert "final_observation" not in list_info[i]
def test_info_to_list_statistics():