Mirror of https://github.com/Farama-Foundation/Gymnasium.git, synced 2025-08-02 14:26:33 +00:00
New Step API with terminated, truncated bools instead of done (#2752)
gym/core.py (66 changed lines)
@@ -130,11 +130,16 @@ class Env(Generic[ObsType, ActType], metaclass=decorator):
def np_random(self, value: RandomNumberGenerator):
self._np_random = value

def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
"""Run one timestep of the environment's dynamics.

When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
Accepts an action and returns a tuple `(observation, reward, done, info)`.
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple
`(observation, reward, done, info)`. The latter is deprecated and will be removed in future versions.

Args:
action (ActType): an action provided by the agent
@@ -143,14 +148,21 @@ class Env(Generic[ObsType, ActType], metaclass=decorator):
observation (object): this will be an element of the environment's :attr:`observation_space`.
This may, for instance, be a numpy array containing the positions and velocities of certain objects.
reward (float): The amount of reward returned as a result of taking the action.
terminated (bool): whether a `terminal state` (as defined under the MDP of the task) is reached.
In this case further step() calls could return undefined results.
truncated (bool): whether a truncation condition outside the scope of the MDP is satisfied.
Typically a timelimit, but could also be used to indicate agent physically going out of bounds.
Can be used to end the episode prematurely before a `terminal state` is reached.
info (dictionary): `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
This might, for instance, contain: metrics that describe the agent's performance state, variables that are
hidden from observations, or individual reward terms that are combined to produce the total reward.
It also can contain information that distinguishes truncation and termination, however this is deprecated in favour
of returning two booleans, and will be removed in a future version.

(deprecated)
done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results.
A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully,
a certain timelimit was exceeded, or the physics simulation has entered an invalid state.
info (dictionary): A dictionary that may contain additional information regarding the reason for a ``done`` signal.
`info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
This might, for instance, contain: metrics that describe the agent's performance state, variables that are
hidden from observations, information that distinguishes truncation and termination or individual reward terms
that are combined to produce the total reward
"""
raise NotImplementedError
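For orientation, a minimal sketch (not part of the diff) of a loop written against the new signature described above; the environment id and the `new_step_api` flag are assumptions based on the changes in this commit:

```python
import gym

# new_step_api is added by this commit; with it, step() returns the 5-tuple.
env = gym.make("CartPole-v1", new_step_api=True)
obs = env.reset(seed=0)
while True:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    # terminated: a terminal state of the MDP was reached (e.g. the pole fell)
    # truncated: the episode was cut short for reasons outside the MDP (e.g. a time limit)
    if terminated or truncated:
        break
env.close()
```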
@@ -298,11 +310,12 @@ class Wrapper(Env[ObsType, ActType]):
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""

def __init__(self, env: Env):
def __init__(self, env: Env, new_step_api: bool = False):
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.

Args:
env: The environment to wrap
new_step_api: Whether the wrapper's step method will output in new or old step API
"""
self.env = env

@@ -310,6 +323,13 @@ class Wrapper(Env[ObsType, ActType]):
self._observation_space: Optional[spaces.Space] = None
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
self._metadata: Optional[dict] = None
self.new_step_api = new_step_api

if not self.new_step_api:
deprecation(
"Initializing wrapper in old step API which returns one bool instead of two. "
"It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. "
)

def __getattr__(self, name):
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
@@ -391,9 +411,17 @@ class Wrapper(Env[ObsType, ActType]):
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
)

def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
"""Steps through the environment with action."""
return self.env.step(action)
from gym.utils.step_api_compatibility import ( # avoid circular import
step_api_compatibility,
)

return step_api_compatibility(self.env.step(action), self.new_step_api)

def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
"""Resets the environment with kwargs."""
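A hedged sketch of what the wrapper change above means for callers: the same wrapped environment can be stepped in either API depending on its `new_step_api` flag. The wrapper class and environment id below are illustrative, not part of the commit:

```python
import gym


class PassThroughWrapper(gym.Wrapper):
    """Hypothetical wrapper used only to illustrate the new_step_api flag."""

    def __init__(self, env, new_step_api=False):
        super().__init__(env, new_step_api)


env = PassThroughWrapper(gym.make("CartPole-v1"), new_step_api=True)
env.reset(seed=0)
print(len(env.step(env.action_space.sample())))  # 5: (obs, reward, terminated, truncated, info)

legacy = PassThroughWrapper(gym.make("CartPole-v1"))  # old API; construction emits a deprecation warning
legacy.reset(seed=0)
print(len(legacy.step(legacy.action_space.sample())))  # 4: (obs, reward, done, info)
```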
@@ -463,8 +491,13 @@ class ObservationWrapper(Wrapper):

def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
observation, reward, done, info = self.env.step(action)
return self.observation(observation), reward, done, info
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return self.observation(observation), reward, terminated, truncated, info
else:
observation, reward, done, info = step_returns
return self.observation(observation), reward, done, info

def observation(self, observation):
"""Returns a modified observation."""
@@ -497,8 +530,13 @@ class RewardWrapper(Wrapper):

def step(self, action):
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
observation, reward, done, info = self.env.step(action)
return observation, self.reward(reward), done, info
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return observation, self.reward(reward), terminated, truncated, info
else:
observation, reward, done, info = step_returns
return observation, self.reward(reward), done, info

def reward(self, reward):
"""Returns a modified ``reward``."""
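For subclasses, nothing changes: only the hook method is overridden, and the tuple shape is handled by the base class logic shown above. A short illustrative sketch (the class name and clipping range are assumptions, not part of the commit):

```python
import gym


class ClipReward(gym.RewardWrapper):
    """Illustrative subclass: only reward() is overridden; the 4- vs 5-tuple
    handling in step() is inherited from RewardWrapper as patched above."""

    def reward(self, reward):
        return max(-1.0, min(1.0, float(reward)))
```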
@@ -599,15 +599,14 @@ class BipedalWalker(gym.Env, EzPickle):
reward -= 0.00035 * MOTORS_TORQUE * np.clip(np.abs(a), 0, 1)
# normalized to about -50.0 using heuristic, more optimal agent should spend less

done = False
terminated = False
if self.game_over or pos[0] < 0:
reward = -100
done = True
terminated = True
if pos[0] > (TERRAIN_LENGTH - TERRAIN_GRASS) * TERRAIN_STEP:
done = True

terminated = True
self.renderer.render_step()
return np.array(state, dtype=np.float32), reward, done, {}
return np.array(state, dtype=np.float32), reward, terminated, False, {}

def render(self, mode: str = "human"):
if self.render_mode is not None:
@@ -789,9 +788,9 @@ if __name__ == "__main__":
SUPPORT_KNEE_ANGLE = +0.1
supporting_knee_angle = SUPPORT_KNEE_ANGLE
while True:
s, r, done, info = env.step(a)
s, r, terminated, truncated, info = env.step(a)
total_reward += r
if steps % 20 == 0 or done:
if steps % 20 == 0 or terminated or truncated:
print("\naction " + str([f"{x:+0.2f}" for x in a]))
print(f"step {steps} total_reward {total_reward:+0.2f}")
print("hull " + str([f"{x:+0.2f}" for x in s[0:4]]))
@@ -854,5 +853,5 @@ if __name__ == "__main__":
a[3] = knee_todo[1]
a = np.clip(0.5 * a, -1.0, 1.0)

if done:
if terminated or truncated:
break
@@ -526,8 +526,8 @@ class CarRacing(gym.Env, EzPickle):
self.state = self._render("single_state_pixels")

step_reward = 0
done = False
info = {}
terminated = False
truncated = False
if action is not None: # First step without action, called from reset()
self.reward -= 0.1
# We actually don't want to count fuel spent, we want car to be faster.
@@ -536,18 +536,17 @@ class CarRacing(gym.Env, EzPickle):
step_reward = self.reward - self.prev_reward
self.prev_reward = self.reward
if self.tile_visited_count == len(self.track) or self.new_lap:
done = True
# Termination due to finishing lap
# Truncation due to finishing lap
# This should not be treated as a failure
# but like a timeout
info["TimeLimit.truncated"] = True
truncated = True
x, y = self.car.hull.position
if abs(x) > PLAYFIELD or abs(y) > PLAYFIELD:
done = True
terminated = True
step_reward = -100

self.renderer.render_step()
return self.state, step_reward, done, info
return self.state, step_reward, terminated, truncated, {}

def render(self, mode: str = "human"):
if self.render_mode is not None:
@@ -811,13 +810,13 @@ if __name__ == "__main__":
restart = False
while True:
register_input()
s, r, done, info = env.step(a)
s, r, terminated, truncated, info = env.step(a)
total_reward += r
if steps % 200 == 0 or done:
if steps % 200 == 0 or terminated or truncated:
print("\naction " + str([f"{x:+0.2f}" for x in a]))
print(f"step {steps} total_reward {total_reward:+0.2f}")
steps += 1
isopen = env.render()
if done or restart or isopen is False:
if terminated or truncated or restart or isopen is False:
break
env.close()
@@ -11,6 +11,7 @@ from gym import error, spaces
from gym.error import DependencyNotInstalled
from gym.utils import EzPickle, colorize
from gym.utils.renderer import Renderer
from gym.utils.step_api_compatibility import step_api_compatibility

try:
import Box2D
@@ -577,16 +578,15 @@ class LunarLander(gym.Env, EzPickle):
) # less fuel spent is better, about -30 for heuristic landing
reward -= s_power * 0.03

done = False
terminated = False
if self.game_over or abs(state[0]) >= 1.0:
done = True
terminated = True
reward = -100
if not self.lander.awake:
done = True
terminated = True
reward = +100

self.renderer.render_step()
return np.array(state, dtype=np.float32), reward, done, {}
return np.array(state, dtype=np.float32), reward, terminated, False, {}

def render(self, mode="human"):
if self.render_mode is not None:
@@ -771,7 +771,7 @@ def demo_heuristic_lander(env, seed=None, render=False):
s = env.reset(seed=seed)
while True:
a = heuristic(env, s)
s, r, done, info = env.step(a)
s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True)
total_reward += r

if render:
@@ -779,11 +779,11 @@ def demo_heuristic_lander(env, seed=None, render=False):
if still_open is False:
break

if steps % 20 == 0 or done:
if steps % 20 == 0 or terminated or truncated:
print("observations:", " ".join([f"{x:+0.2f}" for x in s]))
print(f"step {steps} total_reward {total_reward:+0.2f}")
steps += 1
if done:
if terminated or truncated:
break
if render:
env.close()
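The heuristic demo above converts on the caller side with `step_api_compatibility(env.step(a), True)`. A hedged sketch of the same pattern for user code that must run against environments in either API (the `rollout` helper and `policy` callable are illustrative assumptions):

```python
from gym.utils.step_api_compatibility import step_api_compatibility


def rollout(env, policy, seed=None):
    """Run one episode, converting step returns to the new API on the caller side,
    mirroring demo_heuristic_lander above. `policy` is any callable obs -> action."""
    obs = env.reset(seed=seed)
    total_reward = 0.0
    while True:
        obs, reward, terminated, truncated, info = step_api_compatibility(
            env.step(policy(obs)), True
        )
        total_reward += reward
        if terminated or truncated:
            return total_reward
```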
@@ -86,12 +86,12 @@ class AcrobotEnv(core.Env):
Each parameter in the underlying state (`theta1`, `theta2`, and the two angular velocities) is initialized
uniformly between -0.1 and 0.1. This means both links are pointing downwards with some initial stochasticity.

### Episode Termination
### Episode End

The episode terminates if one of the following occurs:
1. The free end reaches the target height, which is constructed as:
The episode ends if one of the following occurs:
1. Termination: The free end reaches the target height, which is constructed as:
`-cos(theta1) - cos(theta2 + theta1) > 1.0`
2. Episode length is greater than 500 (200 for v0)
2. Truncation: Episode length is greater than 500 (200 for v0)

### Arguments

@@ -226,11 +226,11 @@ class AcrobotEnv(core.Env):
ns[2] = bound(ns[2], -self.MAX_VEL_1, self.MAX_VEL_1)
ns[3] = bound(ns[3], -self.MAX_VEL_2, self.MAX_VEL_2)
self.state = ns
terminal = self._terminal()
reward = -1.0 if not terminal else 0.0
terminated = self._terminal()
reward = -1.0 if not terminated else 0.0

self.renderer.render_step()
return self._get_ob(), reward, terminal, {}
return (self._get_ob(), reward, terminated, False, {})

def _get_ob(self):
s = self.state
@@ -65,12 +65,13 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):

All observations are assigned a uniformly random value in `(-0.05, 0.05)`

### Episode Termination
### Episode End

The episode terminates if any one of the following occurs:
1. Pole Angle is greater than ±12°
2. Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
3. Episode length is greater than 500 (200 for v0)
The episode ends if any one of the following occurs:

1. Termination: Pole Angle is greater than ±12°
2. Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
3. Truncation: Episode length is greater than 500 (200 for v0)

### Arguments

@@ -126,7 +127,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.isopen = True
self.state = None

self.steps_beyond_done = None
self.steps_beyond_terminated = None

def step(self, action):
err_msg = f"{action!r} ({type(action)}) invalid"
@@ -160,32 +161,32 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):

self.state = (x, x_dot, theta, theta_dot)

done = bool(
terminated = bool(
x < -self.x_threshold
or x > self.x_threshold
or theta < -self.theta_threshold_radians
or theta > self.theta_threshold_radians
)

if not done:
if not terminated:
reward = 1.0
elif self.steps_beyond_done is None:
elif self.steps_beyond_terminated is None:
# Pole just fell!
self.steps_beyond_done = 0
self.steps_beyond_terminated = 0
reward = 1.0
else:
if self.steps_beyond_done == 0:
if self.steps_beyond_terminated == 0:
logger.warn(
"You are calling 'step()' even though this "
"environment has already returned done = True. You "
"should always call 'reset()' once you receive 'done = "
"environment has already returned terminated = True. You "
"should always call 'reset()' once you receive 'terminated = "
"True' -- any further steps are undefined behavior."
)
self.steps_beyond_done += 1
self.steps_beyond_terminated += 1
reward = 0.0

self.renderer.render_step()
return np.array(self.state, dtype=np.float32), reward, done, {}
return np.array(self.state, dtype=np.float32), reward, terminated, False, {}

def reset(
self,
@@ -201,7 +202,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
options, -0.05, 0.05 # default low
) # default high
self.state = self.np_random.uniform(low=low, high=high, size=(4,))
self.steps_beyond_done = None
self.steps_beyond_terminated = None
self.renderer.reset()
self.renderer.render_step()
if not return_info:
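Why the split matters for learning code (illustrative, not part of the commit): a one-step bootstrapped target should zero out the next-state value on a true termination such as CartPole's pole falling, but not on a time-limit truncation.

```python
def td_target(reward, next_value, terminated, gamma=0.99):
    """One-step TD target under the new API.

    Only `terminated` removes the bootstrap term; a truncation (e.g. CartPole's
    500-step limit) still bootstraps from the estimated value of the next state.
    """
    return reward + gamma * (0.0 if terminated else next_value)
```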
@@ -84,11 +84,11 @@ class Continuous_MountainCarEnv(gym.Env):
The position of the car is assigned a uniform random value in `[-0.6 , -0.4]`.
The starting velocity of the car is always assigned to 0.

### Episode Termination
### Episode End

The episode terminates if either of the following happens:
1. The position of the car is greater than or equal to 0.45 (the goal position on top of the right hill)
2. The length of the episode is 999.
The episode ends if either of the following happens:
1. Termination: The position of the car is greater than or equal to 0.45 (the goal position on top of the right hill)
2. Truncation: The length of the episode is 999.

### Arguments

@@ -161,17 +161,18 @@ class Continuous_MountainCarEnv(gym.Env):
velocity = 0

# Convert a possible numpy bool to a Python bool.
done = bool(position >= self.goal_position and velocity >= self.goal_velocity)
terminated = bool(
position >= self.goal_position and velocity >= self.goal_velocity
)

reward = 0
if done:
if terminated:
reward = 100.0
reward -= math.pow(action[0], 2) * 0.1

self.state = np.array([position, velocity], dtype=np.float32)

self.renderer.render_step()
return self.state, reward, done, {}
return self.state, reward, terminated, False, {}

def reset(
self,
@@ -78,11 +78,11 @@ class MountainCarEnv(gym.Env):
The position of the car is assigned a uniform random value in *[-0.6 , -0.4]*.
The starting velocity of the car is always assigned to 0.

### Episode Termination
### Episode End

The episode terminates if either of the following happens:
1. The position of the car is greater than or equal to 0.5 (the goal position on top of the right hill)
2. The length of the episode is 200.
The episode ends if either of the following happens:
1. Termination: The position of the car is greater than or equal to 0.5 (the goal position on top of the right hill)
2. Truncation: The length of the episode is 200.

### Arguments
@@ -139,13 +139,14 @@ class MountainCarEnv(gym.Env):
if position == self.min_position and velocity < 0:
velocity = 0

done = bool(position >= self.goal_position and velocity >= self.goal_velocity)
terminated = bool(
position >= self.goal_position and velocity >= self.goal_velocity
)
reward = -1.0

self.state = (position, velocity)

self.renderer.render_step()
return np.array(self.state, dtype=np.float32), reward, done, {}
return np.array(self.state, dtype=np.float32), reward, terminated, False, {}

def reset(
self,
@@ -68,9 +68,9 @@ class PendulumEnv(gym.Env):

The starting state is a random angle in *[-pi, pi]* and a random angular velocity in *[-1,1]*.

### Episode Termination
### Episode Truncation

The episode terminates at 200 time steps.
The episode truncates at 200 time steps.

### Arguments

@@ -136,7 +136,7 @@ class PendulumEnv(gym.Env):

self.state = np.array([newth, newthdot])
self.renderer.render_step()
return self._get_obs(), -costs, False, {}
return self._get_obs(), -costs, False, False, {}

def reset(
self,
@@ -41,13 +41,16 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
survive_reward = 1.0
reward = forward_reward - ctrl_cost - contact_cost + survive_reward
state = self.state_vector()
notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
done = not notdone
not_terminated = (
np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
)
terminated = not not_terminated
ob = self._get_obs()
return (
ob,
reward,
done,
terminated,
False,
dict(
reward_forward=forward_reward,
reward_ctrl=-ctrl_cost,
@@ -97,9 +97,9 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
return is_healthy

@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated

def step(self, action):
xy_position_before = self.get_body_com("torso")[:2].copy()
@@ -121,7 +121,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
self.renderer.render_step()

reward = rewards - costs
done = self.done
terminated = self.terminated
observation = self._get_obs()
info = {
"reward_forward": forward_reward,
@@ -136,7 +136,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
"forward_reward": forward_reward,
}

return observation, reward, done, info
return observation, reward, terminated, False, info

def _get_obs(self):
position = self.sim.data.qpos.flat.copy()
@@ -124,19 +124,19 @@ class AntEnv(MujocoEnv, utils.EzPickle):
to be slightly high, thereby indicating a standing up ant. The initial orientation
is designed to make it face forward as well.

### Episode Termination
### Episode End
The ant is said to be unhealthy if any of the following happens:

1. Any of the state space values is no longer finite
2. The z-coordinate of the torso is **not** in the closed interval given by `healthy_z_range` (defaults to [0.2, 1.0])

If `terminate_when_unhealthy=True` is passed during construction (which is the default),
the episode terminates when any of the following happens:
the episode ends when any of the following happens:

1. The episode duration reaches a 1000 timesteps
2. The ant is unhealthy
1. Termination: The episode duration reaches a 1000 timesteps
2. Truncation: The ant is unhealthy

If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded.
If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.

### Arguments

@@ -263,9 +263,9 @@ class AntEnv(MujocoEnv, utils.EzPickle):
return is_healthy

@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated

def step(self, action):
xy_position_before = self.get_body_com("torso")[:2].copy()
@@ -282,7 +282,7 @@ class AntEnv(MujocoEnv, utils.EzPickle):

costs = ctrl_cost = self.control_cost(action)

done = self.done
terminated = self.terminated
observation = self._get_obs()
info = {
"reward_forward": forward_reward,
@@ -303,7 +303,7 @@ class AntEnv(MujocoEnv, utils.EzPickle):
reward = rewards - costs

self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info

def _get_obs(self):
position = self.data.qpos.flat.copy()
@@ -35,8 +35,14 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
reward_ctrl = -0.1 * np.square(action).sum()
reward_run = (xposafter - xposbefore) / self.dt
reward = reward_ctrl + reward_run
done = False
return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl)
terminated = False
return (
ob,
reward,
terminated,
False,
dict(reward_run=reward_run, reward_ctrl=reward_ctrl),
)

def _get_obs(self):
return np.concatenate(
@@ -75,7 +75,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):

observation = self._get_obs()
reward = forward_reward - ctrl_cost
done = False
terminated = False
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
@@ -83,7 +83,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
"reward_ctrl": -ctrl_cost,
}

return observation, reward, done, info
return observation, reward, terminated, False, info

def _get_obs(self):
position = self.sim.data.qpos.flat.copy()
@@ -98,8 +98,8 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
normal noise with a mean of 0 and standard deviation of `reset_noise_scale` is added to the
initial velocity values of all zeros.

### Episode Termination
The episode terminates when the episode length is greater than 1000.
### Episode End
The episode truncates when the episode length is greater than 1000.

### Arguments

@@ -192,7 +192,7 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):

observation = self._get_obs()
reward = forward_reward - ctrl_cost
done = False
terminated = False
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
@@ -201,7 +201,7 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
}

self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info

def _get_obs(self):
position = self.data.qpos.flat.copy()
@@ -36,14 +36,14 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
reward += alive_bonus
reward -= 1e-3 * np.square(a).sum()
s = self.state_vector()
done = not (
terminated = not (
np.isfinite(s).all()
and (np.abs(s[2:]) < 100).all()
and (height > 0.7)
and (abs(ang) < 0.2)
)
ob = self._get_obs()
return ob, reward, done, {}
return ob, reward, terminated, False, {}

def _get_obs(self):
return np.concatenate(
@@ -101,9 +101,9 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
return is_healthy

@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated

def _get_obs(self):
position = self.sim.data.qpos.flat.copy()
@@ -133,13 +133,13 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):

observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
}

return observation, reward, done, info
return observation, reward, terminated, False, info

def reset_model(self):
noise_low = -self._reset_noise_scale
@@ -87,7 +87,7 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
(0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) with a uniform noise
in the range of [-`reset_noise_scale`, `reset_noise_scale`] added to the values for stochasticity.

### Episode Termination
### Episode End
The hopper is said to be unhealthy if any of the following happens:

1. An element of `observation[1:]` (if `exclude_current_positions_from_observation=True`, else `observation[2:]`) is no longer contained in the closed interval specified by the argument `healthy_state_range`
@@ -95,12 +95,12 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
3. The angle (`observation[1]` if `exclude_current_positions_from_observation=True`, else `observation[2]`) is no longer contained in the closed interval specified by the argument `healthy_angle_range`

If `terminate_when_unhealthy=True` is passed during construction (which is the default),
the episode terminates when any of the following happens:
the episode ends when any of the following happens:

1. The episode duration reaches a 1000 timesteps
2. The hopper is unhealthy
1. Truncation: The episode duration reaches a 1000 timesteps
2. Termination: The hopper is unhealthy

If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded.
If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.

### Arguments

@@ -223,9 +223,9 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
return is_healthy

@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated

def _get_obs(self):
position = self.data.qpos.flat.copy()
@@ -253,14 +253,14 @@ class HopperEnv(MujocoEnv, utils.EzPickle):

observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
}

self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info

def reset_model(self):
noise_low = -self._reset_noise_scale
@@ -60,11 +60,12 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
quad_impact_cost = min(quad_impact_cost, 10)
reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus
qpos = self.sim.data.qpos
done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0))
terminated = bool((qpos[2] < 1.0) or (qpos[2] > 2.0))
return (
self._get_obs(),
reward,
done,
terminated,
False,
dict(
reward_linvel=lin_vel_cost,
reward_quadctrl=-quad_ctrl_cost,
@@ -99,9 +99,9 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
return is_healthy

@property
def done(self):
done = (not self.is_healthy) if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = (not self.is_healthy) if self._terminate_when_unhealthy else False
return terminated

def _get_obs(self):
position = self.sim.data.qpos.flat.copy()
@@ -148,7 +148,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):

observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"reward_linvel": forward_reward,
"reward_quadctrl": -ctrl_cost,
@@ -162,7 +162,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
"forward_reward": forward_reward,
}

return observation, reward, done, info
return observation, reward, terminated, False, info

def reset_model(self):
noise_low = -self._reset_noise_scale
@@ -165,18 +165,17 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
selected to be high, thereby indicating a standing up humanoid. The initial
orientation is designed to make it face forward as well.

### Episode Termination
### Episode End
The humanoid is said to be unhealthy if the z-position of the torso is no longer contained in the
closed interval specified by the argument `healthy_z_range`.

If `terminate_when_unhealthy=True` is passed during construction (which is the default),
the episode terminates when any of the following happens:
the episode ends when any of the following happens:

1. The episode duration reaches a 1000 timesteps
3. The humanoid is unhealthy

If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded.
1. Truncation: The episode duration reaches a 1000 timesteps
3. Termination: The humanoid is unhealthy

If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.

### Arguments

@@ -281,9 +280,9 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
return is_healthy

@property
def done(self):
done = (not self.is_healthy) if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = (not self.is_healthy) if self._terminate_when_unhealthy else False
return terminated

def _get_obs(self):
position = self.data.qpos.flat.copy()
@@ -326,7 +325,7 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):

observation = self._get_obs()
reward = rewards - ctrl_cost
done = self.done
terminated = self.terminated
info = {
"reward_linvel": forward_reward,
"reward_quadctrl": -ctrl_cost,
@@ -340,7 +339,7 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
}

self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info

def reset_model(self):
noise_low = -self._reset_noise_scale
@@ -56,11 +56,11 @@ class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):

self.renderer.render_step()

done = bool(False)
return (
self._get_obs(),
reward,
done,
False,
False,
dict(
reward_linup=uph_cost,
reward_quadctrl=-quad_ctrl_cost,
@@ -151,11 +151,11 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
to be low, thereby indicating a laying down humanoid. The initial orientation is
designed to make it face forward as well.

### Episode Termination
The episode terminates when any of the following happens:
### Episode End
The episode ends when any of the following happens:

1. The episode duration reaches a 1000 timesteps
2. Any of the state space values is no longer finite
1. Truncation: The episode duration reaches a 1000 timesteps
2. Termination: Any of the state space values is no longer finite

### Arguments

@@ -228,11 +228,11 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):

self.renderer.render_step()

done = bool(False)
return (
self._get_obs(),
reward,
done,
False,
False,
dict(
reward_linup=uph_cost,
reward_quadctrl=-quad_ctrl_cost,
@@ -40,8 +40,8 @@ class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2
alive_bonus = 10
r = alive_bonus - dist_penalty - vel_penalty
done = bool(y <= 1)
return ob, r, done, {}
terminated = bool(y <= 1)
return ob, r, terminated, False, {}

def _get_obs(self):
return np.concatenate(
@@ -85,12 +85,12 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
of [-0.1, 0.1] added to the positional values (cart position and pole angles) and standard
normal force with a standard deviation of 0.1 added to the velocity values for stochasticity.

### Episode Termination
The episode terminates when any of the following happens:
### Episode End
The episode ends when any of the following happens:

1. The episode duration reaches 1000 timesteps.
2. Any of the state space values is no longer finite.
3. The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other).
1. Truncation: The episode duration reaches 1000 timesteps.
2. Termination: Any of the state space values is no longer finite.
3. Termination: The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other).

### Arguments

@@ -143,11 +143,9 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2
alive_bonus = 10
r = alive_bonus - dist_penalty - vel_penalty
done = bool(y <= 1)

terminated = bool(y <= 1)
self.renderer.render_step()

return ob, r, done, {}
return ob, r, terminated, False, {}

def _get_obs(self):
return np.concatenate(
@@ -35,9 +35,8 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
self.renderer.render_step()

ob = self._get_obs()
notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2)
done = not notdone
return ob, reward, done, {}
terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2))
return ob, reward, terminated, False, {}

def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(
@@ -56,12 +56,12 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
(0.0, 0.0, 0.0, 0.0) with a uniform noise in the range
of [-0.01, 0.01] added to the values for stochasticity.

### Episode Termination
The episode terminates when any of the following happens:
### Episode End
The episode ends when any of the following happens:

1. The episode duration reaches 1000 timesteps.
2. Any of the state space values is no longer finite.
3. The absolute value of the vertical angle between the pole and the cart is greater than 0.2 radians.
1. Truncation: The episode duration reaches 1000 timesteps.
2. Termination: Any of the state space values is no longer finite.
3. Termination: The absolute value of the vertical angle between the pole and the cart is greater than 0.2 radians.

### Arguments

@@ -109,12 +109,9 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
reward = 1.0
self.do_simulation(a, self.frame_skip)
ob = self._get_obs()
notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2)
done = not notdone

terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2))
self.renderer.render_step()

return ob, reward, done, {}
return ob, reward, terminated, False, {}

def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(
@@ -38,8 +38,13 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle):
self.renderer.render_step()

ob = self._get_obs()
done = False
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
)

def viewer_setup(self):
assert self.viewer is not None
@@ -99,12 +99,12 @@ class PusherEnv(MujocoEnv, utils.EzPickle):

The default framerate is 5 with each frame lasting for 0.01, giving rise to a *dt = 5 * 0.01 = 0.05*

### Episode Termination
### Episode End

The episode terminates when any of the following happens:
The episode ends when any of the following happens:

1. The episode duration reaches a 100 timesteps.
2. Any of the state space values is no longer finite.
1. Truncation: The episode duration reaches a 100 timesteps.
2. Termination: Any of the state space values is no longer finite.

### Arguments

@@ -157,11 +157,14 @@ class PusherEnv(MujocoEnv, utils.EzPickle):

self.do_simulation(a, self.frame_skip)
ob = self._get_obs()
done = False

self.renderer.render_step()

return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
)

def viewer_setup(self):
assert self.viewer is not None
@@ -34,8 +34,13 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
self.renderer.render_step()

ob = self._get_obs()
done = False
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
)

def viewer_setup(self):
assert self.viewer is not None
@@ -89,12 +89,12 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
element ("fingertip" - "target") is calculated at the end once everything
is set. The default setting has a framerate of 2 and a *dt = 2 * 0.01 = 0.02*

### Episode Termination
### Episode End

The episode terminates when any of the following happens:
The episode ends when any of the following happens:

1. The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps)
2. Any of the state space values is no longer finite.
1. Truncation: The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps)
2. Termination: Any of the state space values is no longer finite.

### Arguments

@@ -143,11 +143,14 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
reward = reward_dist + reward_ctrl
self.do_simulation(a, self.frame_skip)
ob = self._get_obs()
done = False

self.renderer.render_step()

return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
)

def viewer_setup(self):
assert self.viewer is not None
@@ -36,7 +36,13 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
reward = reward_fwd + reward_ctrl
ob = self._get_obs()
return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
return (
ob,
reward,
False,
False,
dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl),
)

def _get_obs(self):
qpos = self.sim.data.qpos
@@ -73,7 +73,6 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):

observation = self._get_obs()
reward = forward_reward - ctrl_cost
done = False
info = {
"reward_fwd": forward_reward,
"reward_ctrl": -ctrl_cost,
@@ -85,7 +84,7 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
"forward_reward": forward_reward,
}

return observation, reward, done, info
return observation, reward, False, False, info

def _get_obs(self):
position = self.sim.data.qpos.flat.copy()
@@ -89,8 +89,8 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
### Starting State
All observations start in state (0,0,0,0,0,0,0,0) with a Uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] is added to the initial state for stochasticity.

### Episode Termination
The episode terminates when the episode length is greater than 1000.
### Episode End
The episode truncates when the episode length is greater than 1000.

### Arguments

@@ -183,7 +183,6 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):

observation = self._get_obs()
reward = forward_reward - ctrl_cost
done = False
info = {
"reward_fwd": forward_reward,
"reward_ctrl": -ctrl_cost,
@@ -196,7 +195,7 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
}

self.renderer.render_step()
return observation, reward, done, info
return observation, reward, False, False, info

def _get_obs(self):
position = self.data.qpos.flat.copy()
@@ -35,10 +35,9 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
reward = (posafter - posbefore) / self.dt
reward += alive_bonus
reward -= 1e-3 * np.square(a).sum()
done = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0)
terminated = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0)
ob = self._get_obs()

return ob, reward, done, {}
return ob, reward, terminated, False, {}

def _get_obs(self):
qpos = self.sim.data.qpos
@@ -92,9 +92,9 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
return is_healthy

@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated

def _get_obs(self):
position = self.sim.data.qpos.flat.copy()
@@ -123,13 +123,13 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):

observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
}

return observation, reward, done, info
return observation, reward, terminated, False, info

def reset_model(self):
noise_low = -self._reset_noise_scale
@@ -92,7 +92,7 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
(0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
with a uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] added to the values for stochasticity.

### Episode Termination
### Episode End
The walker is said to be unhealthy if any of the following happens:

1. Any of the state space values is no longer finite
@@ -100,12 +100,12 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
3. The absolute value of the angle (`observation[1]` if `exclude_current_positions_from_observation=False`, else `observation[2]`) is ***not*** in the closed interval specified by `healthy_angle_range`

If `terminate_when_unhealthy=True` is passed during construction (which is the default),
the episode terminates when any of the following happens:
the episode ends when any of the following happens:

1. The episode duration reaches a 1000 timesteps
2. The walker is unhealthy
1. Truncation: The episode duration reaches a 1000 timesteps
2. Termination: The walker is unhealthy

If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded.
If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.

### Arguments

@@ -221,9 +221,9 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
return is_healthy

@property
def done(self):
done = not self.is_healthy if self._terminate_when_unhealthy else False
return done
def terminated(self):
terminated = not self.is_healthy if self._terminate_when_unhealthy else False
return terminated

def _get_obs(self):
position = self.data.qpos.flat.copy()
@@ -251,14 +251,14 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):

observation = self._get_obs()
reward = rewards - costs
done = self.done
terminated = self.terminated
info = {
"x_position": x_position_after,
"x_velocity": x_velocity,
}

self.renderer.render_step()
return observation, reward, done, info
return observation, reward, terminated, False, info

def reset_model(self):
noise_low = -self._reset_noise_scale
@@ -23,7 +23,13 @@ from typing import (
import numpy as np

from gym.envs.__relocated__ import internal_env_relocation_map
from gym.wrappers import AutoResetWrapper, HumanRendering, OrderEnforcing, TimeLimit
from gym.wrappers import (
AutoResetWrapper,
HumanRendering,
OrderEnforcing,
StepAPICompatibility,
TimeLimit,
)
from gym.wrappers.env_checker import PassiveEnvChecker

if sys.version_info < (3, 10):
@@ -118,6 +124,7 @@ class EnvSpec:
max_episode_steps: Optional[int] = field(default=None)
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
new_step_api: bool = field(default=False)
kwargs: dict = field(default_factory=dict)

namespace: Optional[str] = field(init=False)
@@ -522,6 +529,7 @@ def make(
id: Union[str, EnvSpec],
max_episode_steps: Optional[int] = None,
autoreset: bool = False,
new_step_api: bool = False,
disable_env_checker: bool = False,
**kwargs,
) -> Env:
@@ -531,6 +539,7 @@ def make(
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0
disable_env_checker: If to disable the environment checker
kwargs: Additional arguments to pass to the environment constructor.

@@ -644,19 +653,21 @@ def make(
if disable_env_checker is False:
env = PassiveEnvChecker(env)

env = StepAPICompatibility(env, new_step_api)

# Add the order enforcing wrapper
if spec_.order_enforce:
env = OrderEnforcing(env)

# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
env = TimeLimit(env, max_episode_steps, new_step_api)
elif spec_.max_episode_steps is not None:
env = TimeLimit(env, spec_.max_episode_steps)
env = TimeLimit(env, spec_.max_episode_steps, new_step_api)

# Add the autoreset wrapper
if autoreset:
env = AutoResetWrapper(env)
env = AutoResetWrapper(env, new_step_api)

return env
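A hedged usage sketch of the updated `make` pipeline above, where the same `new_step_api` flag is threaded through StepAPICompatibility, TimeLimit and AutoResetWrapper; the environment id and step count are illustrative:

```python
import gym

# With new_step_api=True, the time limit is expected to surface as truncated=True
# rather than as done=True plus info["TimeLimit.truncated"].
env = gym.make("MountainCar-v0", max_episode_steps=50, new_step_api=True)
obs = env.reset(seed=0)
for _ in range(50):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(terminated, truncated)  # typically (False, True): the 50-step limit truncated the episode
```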
@@ -137,13 +137,13 @@ class BlackjackEnv(gym.Env):
|
||||
if action: # hit: add a card to players hand and return
|
||||
self.player.append(draw_card(self.np_random))
|
||||
if is_bust(self.player):
|
||||
done = True
|
||||
terminated = True
|
||||
reward = -1.0
|
||||
else:
|
||||
done = False
|
||||
terminated = False
|
||||
reward = 0.0
|
||||
else: # stick: play out the dealers hand, and score
|
||||
done = True
|
||||
terminated = True
|
||||
while sum_hand(self.dealer) < 17:
|
||||
self.dealer.append(draw_card(self.np_random))
|
||||
reward = cmp(score(self.player), score(self.dealer))
|
||||
@@ -158,9 +158,8 @@ class BlackjackEnv(gym.Env):
|
||||
):
|
||||
# Natural gives extra points, but doesn't autowin. Legacy implementation
|
||||
reward = 1.5
|
||||
|
||||
self.renderer.render_step()
|
||||
return self._get_obs(), reward, done, {}
|
||||
return self._get_obs(), reward, terminated, False, {}
|
||||
|
||||
def _get_obs(self):
|
||||
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))
|
||||
|
@@ -111,7 +111,7 @@ class CliffWalkingEnv(Env):
|
||||
delta: Change in position for transition
|
||||
|
||||
Returns:
|
||||
Tuple of ``(1.0, new_state, reward, done)``
|
||||
Tuple of ``(1.0, new_state, reward, terminated)``
|
||||
"""
|
||||
new_position = np.array(current) + np.array(delta)
|
||||
new_position = self._limit_coordinates(new_position).astype(int)
|
||||
@@ -120,17 +120,17 @@ class CliffWalkingEnv(Env):
|
||||
return [(1.0, self.start_state_index, -100, False)]
|
||||
|
||||
terminal_state = (self.shape[0] - 1, self.shape[1] - 1)
|
||||
is_done = tuple(new_position) == terminal_state
|
||||
return [(1.0, new_state, -1, is_done)]
|
||||
is_terminated = tuple(new_position) == terminal_state
|
||||
return [(1.0, new_state, -1, is_terminated)]
|
||||
|
||||
def step(self, a):
|
||||
transitions = self.P[self.s][a]
|
||||
i = categorical_sample([t[0] for t in transitions], self.np_random)
|
||||
p, s, r, d = transitions[i]
|
||||
p, s, r, t = transitions[i]
|
||||
self.s = s
|
||||
self.lastaction = a
|
||||
self.renderer.render_step()
|
||||
return (int(s), r, d, {"prob": p})
|
||||
return (int(s), r, t, False, {"prob": p})
|
||||
|
||||
def reset(
|
||||
self,
|
||||
|
@@ -201,9 +201,9 @@ class FrozenLakeEnv(Env):
|
||||
newrow, newcol = inc(row, col, action)
|
||||
newstate = to_s(newrow, newcol)
|
||||
newletter = desc[newrow, newcol]
|
||||
done = bytes(newletter) in b"GH"
|
||||
terminated = bytes(newletter) in b"GH"
|
||||
reward = float(newletter == b"G")
|
||||
return newstate, reward, done
|
||||
return newstate, reward, terminated
|
||||
|
||||
for row in range(nrow):
|
||||
for col in range(ncol):
|
||||
@@ -242,13 +242,11 @@ class FrozenLakeEnv(Env):
|
||||
def step(self, a):
|
||||
transitions = self.P[self.s][a]
|
||||
i = categorical_sample([t[0] for t in transitions], self.np_random)
|
||||
p, s, r, d = transitions[i]
|
||||
p, s, r, t = transitions[i]
|
||||
self.s = s
|
||||
self.lastaction = a
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
return (int(s), r, d, {"prob": p})
|
||||
return (int(s), r, t, False, {"prob": p})
|
||||
|
||||
def reset(
|
||||
self,
|
||||
|
@@ -156,7 +156,7 @@ class TaxiEnv(Env):
|
||||
reward = (
|
||||
-1
|
||||
) # default reward when there is no pickup/dropoff
|
||||
done = False
|
||||
terminated = False
|
||||
taxi_loc = (row, col)
|
||||
|
||||
if action == 0:
|
||||
@@ -175,7 +175,7 @@ class TaxiEnv(Env):
|
||||
elif action == 5: # dropoff
|
||||
if (taxi_loc == locs[dest_idx]) and pass_idx == 4:
|
||||
new_pass_idx = dest_idx
|
||||
done = True
|
||||
terminated = True
|
||||
reward = 20
|
||||
elif (taxi_loc in locs) and pass_idx == 4:
|
||||
new_pass_idx = locs.index(taxi_loc)
|
||||
@@ -184,7 +184,9 @@ class TaxiEnv(Env):
|
||||
new_state = self.encode(
|
||||
new_row, new_col, new_pass_idx, dest_idx
|
||||
)
|
||||
self.P[state][action].append((1.0, new_state, reward, done))
|
||||
self.P[state][action].append(
|
||||
(1.0, new_state, reward, terminated)
|
||||
)
|
||||
self.initial_state_distrib /= self.initial_state_distrib.sum()
|
||||
self.action_space = spaces.Discrete(num_actions)
|
||||
self.observation_space = spaces.Discrete(num_states)
|
||||
@@ -254,12 +256,11 @@ class TaxiEnv(Env):
|
||||
def step(self, a):
|
||||
transitions = self.P[self.s][a]
|
||||
i = categorical_sample([t[0] for t in transitions], self.np_random)
|
||||
p, s, r, d = transitions[i]
|
||||
p, s, r, t = transitions[i]
|
||||
self.s = s
|
||||
self.lastaction = a
|
||||
self.renderer.render_step()
|
||||
|
||||
return int(s), r, d, {"prob": p, "action_mask": self.action_mask(s)}
|
||||
return (int(s), r, t, False, {"prob": p, "action_mask": self.action_mask(s)})
|
||||
|
||||
def reset(
|
||||
self,
|
||||
|
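The toy-text conversions above all follow the same pattern; as a reading aid, here is a hedged sketch of a converted ``step`` (it mirrors the CliffWalking/FrozenLake/Taxi hunks rather than adding code beyond the commit):

def step(self, a):
    transitions = self.P[self.s][a]
    i = categorical_sample([t[0] for t in transitions], self.np_random)
    # Transition tuples now carry `terminated` instead of `done`.
    p, s, r, terminated = transitions[i]
    self.s = s
    self.lastaction = a
    self.renderer.render_step()
    # Tabular envs only terminate; truncation is left to the TimeLimit wrapper, so it is always False here.
    return int(s), r, terminated, False, {"prob": p}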
@@ -58,7 +58,7 @@ class ResetNeeded(Error):
|
||||
|
||||
|
||||
class ResetNotAllowed(Error):
|
||||
"""When the monitor is active, raised when the user tries to step an environment that's not yet done."""
|
||||
"""When the monitor is active, raised when the user tries to step an environment that's not yet terminated or truncated."""
|
||||
|
||||
|
||||
class InvalidAction(Error):
|
||||
|
@@ -4,6 +4,7 @@ import inspect
|
||||
import numpy as np
|
||||
|
||||
from gym import error, logger, spaces
|
||||
from gym.logger import deprecation
|
||||
|
||||
|
||||
def _check_box_observation_space(observation_space: spaces.Box):
|
||||
@@ -253,14 +254,24 @@ def passive_env_step_check(env, action):
|
||||
"""A passive check for the environment step, investigating the returning data then returning the data unchanged."""
|
||||
result = env.step(action)
|
||||
if len(result) == 4:
|
||||
deprecation(
|
||||
"Core environment is written in old step API which returns one bool instead of two. "
|
||||
"It is recommended to rewrite the environment with new step API. "
|
||||
)
|
||||
obs, reward, done, info = result
|
||||
|
||||
assert isinstance(done, bool), "The `done` signal must be a boolean"
|
||||
assert isinstance(
|
||||
done, bool
|
||||
), f"The `done` signal is of type `{type(done)}` must be a boolean"
|
||||
elif len(result) == 5:
|
||||
obs, reward, terminated, truncated, info = result
|
||||
|
||||
assert isinstance(terminated, bool), "The `terminated` signal must be a boolean"
|
||||
assert isinstance(truncated, bool), "The `truncated` signal must be a boolean"
|
||||
assert isinstance(
|
||||
terminated, bool
|
||||
), f"The `terminated` signal is of type `{type(terminated)}`. It must be a boolean"
|
||||
assert isinstance(
|
||||
truncated, bool
|
||||
), f"The `truncated` signal of type `{type(truncated)}`. It must be a boolean."
|
||||
assert (
|
||||
terminated is False or truncated is False
|
||||
), "Only `terminated` or `truncated` can be true, not both."
|
||||
|
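In short, the passive step checker accepts either return shape but validates the boolean flags; a hedged sketch of what it enforces (``env`` and ``action`` are assumed to exist):

result = env.step(action)
if len(result) == 4:  # old API, deprecated
    obs, reward, done, info = result
    assert isinstance(done, bool)
else:  # new API
    obs, reward, terminated, truncated, info = result
    assert isinstance(terminated, bool) and isinstance(truncated, bool)
    assert not (terminated and truncated)  # at most one may be True in this version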
@@ -1,4 +1,7 @@
|
||||
"""Utilities of visualising an environment."""
|
||||
|
||||
# TODO: Convert to new step API in 1.0
|
||||
|
||||
from collections import deque
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -208,6 +211,10 @@ def play(
|
||||
seed: Random seed used when resetting the environment. If None, no seed is used.
|
||||
noop: The action used when no key input has been entered, or the entered key combination is unknown.
|
||||
"""
|
||||
deprecation(
|
||||
"`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools."
|
||||
)
|
||||
|
||||
env.reset(seed=seed)
|
||||
|
||||
if keys_to_action is None:
|
||||
|
180
gym/utils/step_api_compatibility.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0."""
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym.core import ObsType
|
||||
|
||||
OldStepType = Tuple[
|
||||
Union[ObsType, np.ndarray],
|
||||
Union[float, np.ndarray],
|
||||
Union[bool, np.ndarray],
|
||||
Union[dict, list],
|
||||
]
|
||||
|
||||
NewStepType = Tuple[
|
||||
Union[ObsType, np.ndarray],
|
||||
Union[float, np.ndarray],
|
||||
Union[bool, np.ndarray],
|
||||
Union[bool, np.ndarray],
|
||||
Union[dict, list],
|
||||
]
|
||||
|
||||
|
||||
def step_to_new_api(
|
||||
step_returns: Union[OldStepType, NewStepType], is_vector_env=False
|
||||
) -> NewStepType:
|
||||
"""Function to transform step returns to new step API irrespective of input API.
|
||||
|
||||
Args:
|
||||
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
||||
is_vector_env (bool): Whether the step_returns are from a vector environment
|
||||
"""
|
||||
if len(step_returns) == 5:
|
||||
return step_returns
|
||||
else:
|
||||
assert len(step_returns) == 4
|
||||
observations, rewards, dones, infos = step_returns
|
||||
|
||||
terminateds = []
|
||||
truncateds = []
|
||||
if not is_vector_env:
|
||||
dones = [dones]
|
||||
|
||||
for i in range(len(dones)):
|
||||
# For every condition, handling - info single env / info vector env (list) / info vector env (dict)
|
||||
|
||||
# TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done`
|
||||
if (not is_vector_env and "TimeLimit.truncated" not in infos) or (
|
||||
is_vector_env
|
||||
and (
|
||||
(
|
||||
isinstance(infos, list)
|
||||
and "TimeLimit.truncated" not in infos[i]
|
||||
) # vector env, list info api
|
||||
or (
|
||||
"TimeLimit.truncated" not in infos
|
||||
or (
|
||||
"TimeLimit.truncated" in infos
|
||||
and not infos["_TimeLimit.truncated"][i]
|
||||
)
|
||||
) # vector env, dict info api, if mask is False, it's the same as TimeLimit.truncated attribute not being present for env 'i'
|
||||
)
|
||||
):
|
||||
|
||||
terminateds.append(dones[i])
|
||||
truncateds.append(False)
|
||||
|
||||
# This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not.
|
||||
elif (
|
||||
infos["TimeLimit.truncated"]
|
||||
if not is_vector_env
|
||||
else (
|
||||
infos["TimeLimit.truncated"][i]
|
||||
if isinstance(infos, dict)
|
||||
else infos[i]["TimeLimit.truncated"]
|
||||
)
|
||||
):
|
||||
assert dones[i]
|
||||
terminateds.append(False)
|
||||
truncateds.append(True)
|
||||
else:
|
||||
# This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated,
|
||||
# but it also exceeded maximum timesteps at the same step.
|
||||
assert dones[i]
|
||||
terminateds.append(True)
|
||||
truncateds.append(True)
|
||||
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0],
|
||||
np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0],
|
||||
infos,
|
||||
)
|
||||
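A hedged example of the conversion performed by ``step_to_new_api`` (illustrative values, not part of the commit):

from gym.utils.step_api_compatibility import step_to_new_api

# Old-API return where the episode ended because of the time limit:
old_step = (0, 1.0, True, {"TimeLimit.truncated": True})
obs, reward, terminated, truncated, info = step_to_new_api(old_step)
# -> terminated is False, truncated is True.
# Without the "TimeLimit.truncated" key, `done` would map straight to `terminated`.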
|
||||
|
||||
def step_to_old_api(
|
||||
step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False
|
||||
) -> OldStepType:
|
||||
"""Function to transform step returns to old step API irrespective of input API.
|
||||
|
||||
Args:
|
||||
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
||||
is_vector_env (bool): Whether the step_returns are from a vector environment
|
||||
"""
|
||||
if len(step_returns) == 4:
|
||||
return step_returns
|
||||
else:
|
||||
assert len(step_returns) == 5
|
||||
observations, rewards, terminateds, truncateds, infos = step_returns
|
||||
dones = []
|
||||
if not is_vector_env:
|
||||
terminateds = [terminateds]
|
||||
truncateds = [truncateds]
|
||||
|
||||
n_envs = len(terminateds)
|
||||
|
||||
for i in range(n_envs):
|
||||
dones.append(terminateds[i] or truncateds[i])
|
||||
if truncateds[i]:
|
||||
if is_vector_env:
|
||||
# handle vector infos for dict and list
|
||||
if isinstance(infos, dict):
|
||||
if "TimeLimit.truncated" not in infos:
|
||||
# TODO: This should ideally not be done manually and should use vector_env's _add_info()
|
||||
infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
|
||||
infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
|
||||
|
||||
infos["TimeLimit.truncated"][i] = (
|
||||
not terminateds[i] or infos["TimeLimit.truncated"][i]
|
||||
)
|
||||
infos["_TimeLimit.truncated"][i] = True
|
||||
else:
|
||||
# if vector info is a list
|
||||
infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[
|
||||
i
|
||||
].get("TimeLimit.truncated", False)
|
||||
else:
|
||||
infos["TimeLimit.truncated"] = not terminateds[i] or infos.get(
|
||||
"TimeLimit.truncated", False
|
||||
)
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
np.array(dones, dtype=np.bool_) if is_vector_env else dones[0],
|
||||
infos,
|
||||
)
|
||||
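And the reverse direction, again as a hedged illustration with assumed values:

from gym.utils.step_api_compatibility import step_to_old_api

new_step = (0, 1.0, False, True, {})  # truncated, but not terminated
obs, reward, done, info = step_to_old_api(new_step)
# -> done is True and info["TimeLimit.truncated"] is True, matching the old convention.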
|
||||
|
||||
def step_api_compatibility(
|
||||
step_returns: Union[NewStepType, OldStepType],
|
||||
new_step_api: bool = False,
|
||||
is_vector_env: bool = False,
|
||||
) -> Union[NewStepType, OldStepType]:
|
||||
"""Function to transform step returns to the API specified by `new_step_api` bool.
|
||||
|
||||
Old step API refers to step() method returning (observation, reward, done, info)
|
||||
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
|
||||
(Refer to docs for details on the API change)
|
||||
|
||||
Args:
|
||||
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
||||
new_step_api (bool): Whether the output should be in new step API or old (False by default)
|
||||
is_vector_env (bool): Whether the step_returns are from a vector environment
|
||||
|
||||
Returns:
|
||||
step_returns (tuple): Depending on `new_step_api` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
||||
|
||||
Examples:
|
||||
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API,
|
||||
wrapper is written in new API, and the final step output is desired to be in old API.
|
||||
|
||||
>>> obs, rew, done, info = step_api_compatibility(env.step(action))
|
||||
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True)
|
||||
>>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True)
|
||||
"""
|
||||
if new_step_api:
|
||||
return step_to_new_api(step_returns, is_vector_env)
|
||||
else:
|
||||
return step_to_old_api(step_returns, is_vector_env)
|
@@ -15,6 +15,7 @@ def make(
|
||||
asynchronous: bool = True,
|
||||
wrappers: Optional[Union[callable, List[callable]]] = None,
|
||||
disable_env_checker: bool = False,
|
||||
new_step_api: bool = False,
|
||||
**kwargs,
|
||||
) -> VectorEnv:
|
||||
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
||||
@@ -35,6 +36,7 @@ def make(
|
||||
asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
|
||||
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
|
||||
disable_env_checker: Whether to disable the env checker; if True, it will only run on the first environment created.
|
||||
new_step_api: If True, the vector environment's step method outputs two booleans `terminated`, `truncated` instead of one `done`.
|
||||
**kwargs: Keywords arguments applied during gym.make
|
||||
|
||||
Returns:
|
||||
@@ -46,7 +48,10 @@ def make(
|
||||
|
||||
def _make_env():
|
||||
env = gym.envs.registration.make(
|
||||
id, disable_env_checker=_disable_env_checker, **kwargs
|
||||
id,
|
||||
disable_env_checker=_disable_env_checker,
|
||||
new_step_api=True,
|
||||
**kwargs,
|
||||
)
|
||||
if wrappers is not None:
|
||||
if callable(wrappers):
|
||||
@@ -65,4 +70,8 @@ def make(
|
||||
env_fns = [
|
||||
create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs)
|
||||
]
|
||||
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
|
||||
return (
|
||||
AsyncVectorEnv(env_fns, new_step_api=new_step_api)
|
||||
if asynchronous
|
||||
else SyncVectorEnv(env_fns, new_step_api=new_step_api)
|
||||
)
|
||||
|
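A hedged usage sketch of the new keyword on ``gym.vector.make`` (assumes the standard ``CartPole-v1`` registration):

import gym

envs = gym.vector.make("CartPole-v1", num_envs=4, new_step_api=True)
observations = envs.reset()
observations, rewards, terminateds, truncateds, infos = envs.step(envs.action_space.sample())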
@@ -17,6 +17,7 @@ from gym.error import (
|
||||
CustomSpaceError,
|
||||
NoAsyncCallError,
|
||||
)
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
from gym.vector.utils import (
|
||||
CloudpickleWrapper,
|
||||
clear_mpi_env_vars,
|
||||
@@ -66,6 +67,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
context: Optional[str] = None,
|
||||
daemon: bool = True,
|
||||
worker: Optional[callable] = None,
|
||||
new_step_api: bool = False,
|
||||
):
|
||||
"""Vectorized environment that runs multiple environments in parallel.
|
||||
|
||||
@@ -84,7 +86,8 @@ class AsyncVectorEnv(VectorEnv):
|
||||
the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children,
|
||||
so for some environments you may want to have it set to ``False``.
|
||||
worker: If set, then use that worker in a subprocess instead of a default one.
|
||||
Can be useful to override some inner vector env logic, for instance, how resets on done are handled.
|
||||
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
|
||||
new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done
|
||||
|
||||
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance
|
||||
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
|
||||
@@ -112,6 +115,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
num_envs=len(env_fns),
|
||||
observation_space=observation_space,
|
||||
action_space=action_space,
|
||||
new_step_api=new_step_api,
|
||||
)
|
||||
|
||||
if self.shared_memory:
|
||||
@@ -338,7 +342,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
|
||||
|
||||
Returns:
|
||||
The batched environment step information, obs, reward, done and info
|
||||
The batched environment step information, (obs, reward, terminated, truncated, info) or (obs, reward, done, info) depending on new_step_api
|
||||
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
@@ -358,16 +362,17 @@ class AsyncVectorEnv(VectorEnv):
|
||||
f"The call to `step_wait` has timed out after {timeout} second(s)."
|
||||
)
|
||||
|
||||
observations_list, rewards, dones, infos = [], [], [], {}
|
||||
observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {}
|
||||
successes = []
|
||||
for i, pipe in enumerate(self.parent_pipes):
|
||||
result, success = pipe.recv()
|
||||
obs, rew, done, info = result
|
||||
obs, rew, terminated, truncated, info = step_api_compatibility(result, True)
|
||||
|
||||
successes.append(success)
|
||||
observations_list.append(obs)
|
||||
rewards.append(rew)
|
||||
dones.append(done)
|
||||
terminateds.append(terminated)
|
||||
truncateds.append(truncated)
|
||||
infos = self._add_info(infos, info, i)
|
||||
|
||||
self._raise_if_errors(successes)
|
||||
@@ -380,11 +385,16 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self.observations,
|
||||
)
|
||||
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
np.array(rewards),
|
||||
np.array(dones, dtype=np.bool_),
|
||||
infos,
|
||||
return step_api_compatibility(
|
||||
(
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
np.array(rewards),
|
||||
np.array(terminateds, dtype=np.bool_),
|
||||
np.array(truncateds, dtype=np.bool_),
|
||||
infos,
|
||||
),
|
||||
self.new_step_api,
|
||||
True,
|
||||
)
|
||||
|
||||
def call_async(self, name: str, *args, **kwargs):
|
||||
@@ -604,11 +614,17 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
pipe.send((observation, True))
|
||||
|
||||
elif command == "step":
|
||||
observation, reward, done, info = env.step(data)
|
||||
if done:
|
||||
info["terminal_observation"] = observation
|
||||
(
|
||||
observation,
|
||||
reward,
|
||||
terminated,
|
||||
truncated,
|
||||
info,
|
||||
) = step_api_compatibility(env.step(data), True)
|
||||
if terminated or truncated:
|
||||
info["final_observation"] = observation
|
||||
observation = env.reset()
|
||||
pipe.send(((observation, reward, done, info), True))
|
||||
pipe.send(((observation, reward, terminated, truncated, info), True))
|
||||
elif command == "seed":
|
||||
env.seed(data)
|
||||
pipe.send((None, True))
|
||||
@@ -673,14 +689,20 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
||||
)
|
||||
pipe.send((None, True))
|
||||
elif command == "step":
|
||||
observation, reward, done, info = env.step(data)
|
||||
if done:
|
||||
info["terminal_observation"] = observation
|
||||
(
|
||||
observation,
|
||||
reward,
|
||||
terminated,
|
||||
truncated,
|
||||
info,
|
||||
) = step_api_compatibility(env.step(data), True)
|
||||
if terminated or truncated:
|
||||
info["final_observation"] = observation
|
||||
observation = env.reset()
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
pipe.send(((None, reward, done, info), True))
|
||||
pipe.send(((None, reward, terminated, truncated, info), True))
|
||||
elif command == "seed":
|
||||
env.seed(data)
|
||||
pipe.send((None, True))
|
||||
|
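The worker changes above boil down to one autoreset pattern, restated here as a sketch (this mirrors the diff rather than adding new behaviour):

observation, reward, terminated, truncated, info = step_api_compatibility(env.step(data), True)
if terminated or truncated:
    # The pre-reset observation is preserved for the user before autoresetting.
    info["final_observation"] = observation
    observation = env.reset()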
@@ -6,6 +6,7 @@ import numpy as np
|
||||
|
||||
from gym import Env
|
||||
from gym.spaces import Space
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
from gym.vector.utils import concatenate, create_empty_array, iterate
|
||||
from gym.vector.vector_env import VectorEnv
|
||||
|
||||
@@ -33,6 +34,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
observation_space: Space = None,
|
||||
action_space: Space = None,
|
||||
copy: bool = True,
|
||||
new_step_api: bool = False,
|
||||
):
|
||||
"""Vectorized environment that serially runs multiple environments.
|
||||
|
||||
@@ -60,6 +62,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
num_envs=len(self.envs),
|
||||
observation_space=observation_space,
|
||||
action_space=action_space,
|
||||
new_step_api=new_step_api,
|
||||
)
|
||||
|
||||
self._check_spaces()
|
||||
@@ -67,7 +70,8 @@ class SyncVectorEnv(VectorEnv):
|
||||
self.single_observation_space, n=self.num_envs, fn=np.zeros
|
||||
)
|
||||
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
|
||||
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._actions = None
|
||||
|
||||
def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
|
||||
@@ -108,7 +112,8 @@ class SyncVectorEnv(VectorEnv):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
self._dones[:] = False
|
||||
self._terminateds[:] = False
|
||||
self._truncateds[:] = False
|
||||
observations = []
|
||||
infos = {}
|
||||
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
|
||||
@@ -151,9 +156,15 @@ class SyncVectorEnv(VectorEnv):
|
||||
"""
|
||||
observations, infos = [], {}
|
||||
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
|
||||
observation, self._rewards[i], self._dones[i], info = env.step(action)
|
||||
if self._dones[i]:
|
||||
info["terminal_observation"] = observation
|
||||
(
|
||||
observation,
|
||||
self._rewards[i],
|
||||
self._terminateds[i],
|
||||
self._truncateds[i],
|
||||
info,
|
||||
) = step_api_compatibility(env.step(action), True)
|
||||
if self._terminateds[i] or self._truncateds[i]:
|
||||
info["final_observation"] = observation
|
||||
observation = env.reset()
|
||||
observations.append(observation)
|
||||
infos = self._add_info(infos, info, i)
|
||||
@@ -161,11 +172,16 @@ class SyncVectorEnv(VectorEnv):
|
||||
self.single_observation_space, observations, self.observations
|
||||
)
|
||||
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
np.copy(self._rewards),
|
||||
np.copy(self._dones),
|
||||
infos,
|
||||
return step_api_compatibility(
|
||||
(
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
np.copy(self._rewards),
|
||||
np.copy(self._terminateds),
|
||||
np.copy(self._truncateds),
|
||||
infos,
|
||||
),
|
||||
new_step_api=self.new_step_api,
|
||||
is_vector_env=True,
|
||||
)
|
||||
|
||||
def call(self, name, *args, **kwargs) -> tuple:
|
||||
|
@@ -1,5 +1,5 @@
|
||||
"""Base class for vectorized environments."""
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -24,7 +24,11 @@ class VectorEnv(gym.Env):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_envs: int, observation_space: gym.Space, action_space: gym.Space
|
||||
self,
|
||||
num_envs: int,
|
||||
observation_space: gym.Space,
|
||||
action_space: gym.Space,
|
||||
new_step_api: bool = False,
|
||||
):
|
||||
"""Base class for vectorized environments.
|
||||
|
||||
@@ -32,6 +36,7 @@ class VectorEnv(gym.Env):
|
||||
num_envs: Number of environments in the vectorized environment.
|
||||
observation_space: Observation space of a single environment.
|
||||
action_space: Action space of a single environment.
|
||||
new_step_api (bool): Whether the vector env's step method outputs two boolean arrays (new API) or one boolean array (old API)
|
||||
"""
|
||||
self.num_envs = num_envs
|
||||
self.is_vector_env = True
|
||||
@@ -46,6 +51,13 @@ class VectorEnv(gym.Env):
|
||||
self.single_observation_space = observation_space
|
||||
self.single_action_space = action_space
|
||||
|
||||
self.new_step_api = new_step_api
|
||||
if not self.new_step_api:
|
||||
deprecation(
|
||||
"Initializing vector env in old step API which returns one bool array instead of two. "
|
||||
"It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. "
|
||||
)
|
||||
|
||||
def reset_async(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
@@ -135,7 +147,7 @@ class VectorEnv(gym.Env):
|
||||
actions: element of :attr:`action_space` Batch of actions.
|
||||
|
||||
Returns:
|
||||
Batch of observations, rewards, done and infos
|
||||
Batch of (observations, rewards, terminateds, truncateds, infos) or (observations, rewards, dones, infos)
|
||||
"""
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
@@ -143,7 +155,7 @@ class VectorEnv(gym.Env):
|
||||
def call_async(self, name, *args, **kwargs):
|
||||
"""Calls a method name for each parallel environment asynchronously."""
|
||||
|
||||
def call_wait(self, **kwargs) -> List[Any]:
|
||||
def call_wait(self, **kwargs) -> List[Any]: # type: ignore
|
||||
"""After calling a method in :meth:`call_async`, this function collects the results."""
|
||||
|
||||
def call(self, name: str, *args, **kwargs) -> List[Any]:
|
||||
@@ -251,7 +263,7 @@ class VectorEnv(gym.Env):
|
||||
infos[k], infos[f"_{k}"] = info_array, array_mask
|
||||
return infos
|
||||
|
||||
def _init_info_arrays(self, dtype: type) -> np.ndarray:
|
||||
def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Initialize the info array.
|
||||
|
||||
Initialize the info array. If the dtype is numeric
|
||||
|
@@ -14,6 +14,7 @@ from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
|
||||
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
|
||||
from gym.wrappers.rescale_action import RescaleAction
|
||||
from gym.wrappers.resize_observation import ResizeObservation
|
||||
from gym.wrappers.step_api_compatibility import StepAPICompatibility
|
||||
from gym.wrappers.time_aware_observation import TimeAwareObservation
|
||||
from gym.wrappers.time_limit import TimeLimit
|
||||
from gym.wrappers.transform_observation import TransformObservation
|
||||
|
@@ -3,6 +3,7 @@ import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.spaces import Box
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
|
||||
try:
|
||||
import cv2
|
||||
@@ -37,6 +38,7 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
grayscale_obs: bool = True,
|
||||
grayscale_newaxis: bool = False,
|
||||
scale_obs: bool = False,
|
||||
new_step_api: bool = False,
|
||||
):
|
||||
"""Wrapper for Atari 2600 preprocessing.
|
||||
|
||||
@@ -45,7 +47,7 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
|
||||
frame_skip (int): The number of frames between new observations, affecting the frequency at which the agent experiences the game.
|
||||
screen_size (int): resize Atari frame
|
||||
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `done=True` whenever a
|
||||
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
|
||||
life is lost.
|
||||
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
|
||||
is returned.
|
||||
@@ -58,7 +60,7 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
DependencyNotInstalled: opencv-python package not installed
|
||||
ValueError: Disable frame-skipping in the original env
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api)
|
||||
if cv2 is None:
|
||||
raise gym.error.DependencyNotInstalled(
|
||||
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
|
||||
@@ -114,20 +116,22 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
|
||||
def step(self, action):
|
||||
"""Applies the preprocessing for an :meth:`env.step`."""
|
||||
total_reward, done, info = 0.0, False, {}
|
||||
total_reward, terminated, truncated, info = 0.0, False, False, {}
|
||||
|
||||
for t in range(self.frame_skip):
|
||||
_, reward, done, info = self.env.step(action)
|
||||
_, reward, terminated, truncated, info = step_api_compatibility(
|
||||
self.env.step(action), True
|
||||
)
|
||||
total_reward += reward
|
||||
self.game_over = done
|
||||
self.game_over = terminated
|
||||
|
||||
if self.terminal_on_life_loss:
|
||||
new_lives = self.ale.lives()
|
||||
done = done or new_lives < self.lives
|
||||
self.game_over = done
|
||||
terminated = terminated or new_lives < self.lives
|
||||
self.game_over = terminated
|
||||
self.lives = new_lives
|
||||
|
||||
if done:
|
||||
if terminated or truncated:
|
||||
break
|
||||
if t == self.frame_skip - 2:
|
||||
if self.grayscale_obs:
|
||||
@@ -139,7 +143,10 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
self.ale.getScreenGrayscale(self.obs_buffer[0])
|
||||
else:
|
||||
self.ale.getScreenRGB(self.obs_buffer[0])
|
||||
return self._get_obs(), total_reward, done, info
|
||||
return step_api_compatibility(
|
||||
(self._get_obs(), total_reward, terminated, truncated, info),
|
||||
self.new_step_api,
|
||||
)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment using preprocessing."""
|
||||
@@ -156,9 +163,11 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
else 0
|
||||
)
|
||||
for _ in range(noops):
|
||||
_, _, done, step_info = self.env.step(0)
|
||||
_, _, terminated, truncated, step_info = step_api_compatibility(
|
||||
self.env.step(0), True
|
||||
)
|
||||
reset_info.update(step_info)
|
||||
if done:
|
||||
if terminated or truncated:
|
||||
if kwargs.get("return_info", False):
|
||||
_, reset_info = self.env.reset(**kwargs)
|
||||
else:
|
||||
|
@@ -1,29 +1,40 @@
|
||||
"""Wrapper that autoreset environments when `done=True`."""
|
||||
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
|
||||
import gym
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
|
||||
|
||||
class AutoResetWrapper(gym.Wrapper):
|
||||
"""A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`.
|
||||
|
||||
When calling step causes :meth:`Env.step` to return done, :meth:`Env.reset` is called,
|
||||
and the return format of :meth:`self.step` is as follows: ``(new_obs, terminal_reward, terminal_done, info)``
|
||||
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
|
||||
and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)``
|
||||
with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
|
||||
- ``new_obs`` is the first observation after calling :meth:`self.env.reset`
|
||||
- ``terminal_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
|
||||
- ``terminal_done`` is always True
|
||||
- ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
|
||||
- ``final_done`` is always True. In the new API, either ``final_terminated`` or ``final_truncated`` is True
|
||||
- ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
|
||||
with an additional key "terminal_observation" containing the observation returned by the last call to :meth:`self.env.step`
|
||||
and "terminal_info" containing the info dict returned by the last call to :meth:`self.env.step`.
|
||||
with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
|
||||
and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.
|
||||
|
||||
Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns done, a
|
||||
Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
|
||||
new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
|
||||
terminal reward and done state from the previous episode.
|
||||
If you need the terminal state from the previous episode, you need to retrieve it via the
|
||||
"terminal_observation" key in the info dict.
|
||||
final reward and done state from the previous episode.
|
||||
If you need the final state from the previous episode, you need to retrieve it via the
|
||||
"final_observation" key in the info dict.
|
||||
Make sure you know what you're doing if you use this wrapper!
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, new_step_api: bool = False):
|
||||
"""A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`.
|
||||
|
||||
Args:
|
||||
env (gym.Env): The environment to apply the wrapper
|
||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env, new_step_api)
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment with action and resets the environment if a done-signal is encountered.
|
||||
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
|
||||
|
||||
Args:
|
||||
action: The action to take
|
||||
@@ -31,22 +42,26 @@ class AutoResetWrapper(gym.Wrapper):
|
||||
Returns:
|
||||
The autoreset environment :meth:`step`
|
||||
"""
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
obs, reward, terminated, truncated, info = step_api_compatibility(
|
||||
self.env.step(action), True
|
||||
)
|
||||
|
||||
if done:
|
||||
if terminated or truncated:
|
||||
|
||||
new_obs, new_info = self.env.reset(return_info=True)
|
||||
assert (
|
||||
"terminal_observation" not in new_info
|
||||
), 'info dict cannot contain key "terminal_observation" '
|
||||
"final_observation" not in new_info
|
||||
), 'info dict cannot contain key "final_observation" '
|
||||
assert (
|
||||
"terminal_info" not in new_info
|
||||
), 'info dict cannot contain key "terminal_info" '
|
||||
"final_info" not in new_info
|
||||
), 'info dict cannot contain key "final_info" '
|
||||
|
||||
new_info["terminal_observation"] = obs
|
||||
new_info["terminal_info"] = info
|
||||
new_info["final_observation"] = obs
|
||||
new_info["final_info"] = info
|
||||
|
||||
obs = new_obs
|
||||
info = new_info
|
||||
|
||||
return obs, reward, done, info
|
||||
return step_api_compatibility(
|
||||
(obs, reward, terminated, truncated, info), self.new_step_api
|
||||
)
|
||||
|
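A hedged example of reading the stashed observation when using the wrapper with the new step API (environment id and import path assumed):

import gym
from gym.wrappers import AutoResetWrapper

env = AutoResetWrapper(gym.make("CartPole-v1", new_step_api=True), new_step_api=True)
obs = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
if terminated or truncated:
    # `obs` is already the first observation of the next episode;
    # the last observation of the finished episode lives in the info dict.
    last_obs = info["final_observation"]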
@@ -26,7 +26,7 @@ class ClipAction(ActionWrapper):
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
assert isinstance(env.action_space, Box)
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
|
||||
def action(self, action):
|
||||
"""Clips the action within the valid bounds.
|
||||
|
@@ -15,7 +15,7 @@ class PassiveEnvChecker(gym.Wrapper):
|
||||
|
||||
def __init__(self, env):
|
||||
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
|
||||
assert hasattr(
|
||||
env, "action_space"
|
||||
|
@@ -35,7 +35,7 @@ class FilterObservation(gym.ObservationWrapper):
|
||||
ValueError: If the environment's observation space is not :class:`spaces.Dict`
|
||||
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
|
||||
wrapped_observation_space = env.observation_space
|
||||
if not isinstance(wrapped_observation_space, spaces.Dict):
|
||||
|
@@ -25,7 +25,7 @@ class FlattenObservation(gym.ObservationWrapper):
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
self.observation_space = spaces.flatten_space(env.observation_space)
|
||||
|
||||
def observation(self, observation):
|
||||
|
@@ -7,6 +7,7 @@ import numpy as np
|
||||
import gym
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.spaces import Box
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
|
||||
|
||||
class LazyFrames:
|
||||
@@ -122,15 +123,22 @@ class FrameStack(gym.ObservationWrapper):
|
||||
(4, 96, 96, 3)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, num_stack: int, lz4_compress: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
num_stack: int,
|
||||
lz4_compress: bool = False,
|
||||
new_step_api: bool = False,
|
||||
):
|
||||
"""Observation wrapper that stacks the observations in a rolling manner.
|
||||
|
||||
Args:
|
||||
env (Env): The environment to apply the wrapper
|
||||
num_stack (int): The number of frames to stack
|
||||
lz4_compress (bool): Use lz4 to compress the frames internally
|
||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api)
|
||||
self.num_stack = num_stack
|
||||
self.lz4_compress = lz4_compress
|
||||
|
||||
@@ -163,11 +171,15 @@ class FrameStack(gym.ObservationWrapper):
|
||||
action: The action to step through the environment with
|
||||
|
||||
Returns:
|
||||
Stacked observations, reward, done and information from the environment
|
||||
Stacked observations, reward, terminated, truncated, and information from the environment
|
||||
"""
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
observation, reward, terminated, truncated, info = step_api_compatibility(
|
||||
self.env.step(action), True
|
||||
)
|
||||
self.frames.append(observation)
|
||||
return self.observation(None), reward, done, info
|
||||
return step_api_compatibility(
|
||||
(self.observation(), reward, terminated, truncated, info), self.new_step_api
|
||||
)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Reset the environment with kwargs.
|
||||
|
@@ -28,7 +28,7 @@ class GrayScaleObservation(gym.ObservationWrapper):
|
||||
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
|
||||
Otherwise, they are of shape AxB.
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
self.keep_dim = keep_dim
|
||||
|
||||
assert (
|
||||
|
@@ -45,7 +45,7 @@ class HumanRendering(gym.Wrapper):
|
||||
Args:
|
||||
env: The environment that is being wrapped
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
assert env.render_mode in [
|
||||
"single_rgb_array",
|
||||
"rgb_array",
|
||||
|
@@ -2,6 +2,7 @@
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
|
||||
|
||||
# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
|
||||
@@ -54,14 +55,15 @@ class NormalizeObservation(gym.core.Wrapper):
|
||||
newly instantiated or the policy was changed recently.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, epsilon: float = 1e-8):
|
||||
def __init__(self, env: gym.Env, epsilon: float = 1e-8, new_step_api: bool = False):
|
||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||
|
||||
Args:
|
||||
env (Env): The environment to apply the wrapper
|
||||
epsilon: A stability parameter that is used when scaling the observations.
|
||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api)
|
||||
self.num_envs = getattr(env, "num_envs", 1)
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
if self.is_vector_env:
|
||||
@@ -72,12 +74,18 @@ class NormalizeObservation(gym.core.Wrapper):
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment and normalizes the observation."""
|
||||
obs, rews, dones, infos = self.env.step(action)
|
||||
obs, rews, terminateds, truncateds, infos = step_api_compatibility(
|
||||
self.env.step(action), True, self.is_vector_env
|
||||
)
|
||||
if self.is_vector_env:
|
||||
obs = self.normalize(obs)
|
||||
else:
|
||||
obs = self.normalize(np.array([obs]))[0]
|
||||
return obs, rews, dones, infos
|
||||
return step_api_compatibility(
|
||||
(obs, rews, terminateds, truncateds, infos),
|
||||
self.new_step_api,
|
||||
self.is_vector_env,
|
||||
)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment and normalizes the observation."""
|
||||
@@ -117,6 +125,7 @@ class NormalizeReward(gym.core.Wrapper):
|
||||
env: gym.Env,
|
||||
gamma: float = 0.99,
|
||||
epsilon: float = 1e-8,
|
||||
new_step_api: bool = False,
|
||||
):
|
||||
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||
|
||||
@@ -124,8 +133,9 @@ class NormalizeReward(gym.core.Wrapper):
|
||||
env (env): The environment to apply the wrapper
|
||||
epsilon (float): A stability parameter
|
||||
gamma (float): The discount factor that is used in the exponential moving average.
|
||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api)
|
||||
self.num_envs = getattr(env, "num_envs", 1)
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
self.return_rms = RunningMeanStd(shape=())
|
||||
@@ -135,15 +145,25 @@ class NormalizeReward(gym.core.Wrapper):
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment, normalizing the rewards returned."""
|
||||
obs, rews, dones, infos = self.env.step(action)
|
||||
obs, rews, terminateds, truncateds, infos = step_api_compatibility(
|
||||
self.env.step(action), True, self.is_vector_env
|
||||
)
|
||||
if not self.is_vector_env:
|
||||
rews = np.array([rews])
|
||||
self.returns = self.returns * self.gamma + rews
|
||||
rews = self.normalize(rews)
|
||||
if not self.is_vector_env:
|
||||
dones = terminateds or truncateds
|
||||
else:
|
||||
dones = np.bitwise_or(terminateds, truncateds)
|
||||
self.returns[dones] = 0.0
|
||||
if not self.is_vector_env:
|
||||
rews = rews[0]
|
||||
return obs, rews, dones, infos
|
||||
return step_api_compatibility(
|
||||
(obs, rews, terminateds, truncateds, infos),
|
||||
self.new_step_api,
|
||||
self.is_vector_env,
|
||||
)
|
||||
|
||||
def normalize(self, rews):
|
||||
"""Normalizes the rewards with the running mean rewards and their variance."""
|
||||
|
@@ -26,7 +26,7 @@ class OrderEnforcing(gym.Wrapper):
|
||||
env: The environment to wrap
|
||||
disable_render_order_enforcing: If to disable render order enforcing
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
self._has_reset: bool = False
|
||||
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
||||
|
||||
|
@@ -77,7 +77,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
||||
specified ``pixel_keys``.
|
||||
TypeError: When an unexpected pixel type is used
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
|
||||
# Avoid side-effects that occur when render_kwargs is manipulated
|
||||
render_kwargs = copy.deepcopy(render_kwargs)
|
||||
|
@@ -6,6 +6,7 @@ from typing import Optional
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
|
||||
|
||||
def add_vector_episode_statistics(
|
||||
@@ -76,14 +77,15 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
length_queue: The lengths of the last ``deque_size``-many episodes
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, deque_size: int = 100):
|
||||
def __init__(self, env: gym.Env, deque_size: int = 100, new_step_api: bool = False):
|
||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||
|
||||
Args:
|
||||
env (Env): The environment to apply the wrapper
|
||||
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api)
|
||||
self.num_envs = getattr(env, "num_envs", 1)
|
||||
self.t0 = time.perf_counter()
|
||||
self.episode_count = 0
|
||||
@@ -102,18 +104,26 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment, recording the episode statistics."""
|
||||
observations, rewards, dones, infos = super().step(action)
|
||||
(
|
||||
observations,
|
||||
rewards,
|
||||
terminateds,
|
||||
truncateds,
|
||||
infos,
|
||||
) = step_api_compatibility(self.env.step(action), True, self.is_vector_env)
|
||||
assert isinstance(
|
||||
infos, dict
|
||||
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
|
||||
self.episode_returns += rewards
|
||||
self.episode_lengths += 1
|
||||
if not self.is_vector_env:
|
||||
dones = [dones]
|
||||
dones = list(dones)
|
||||
terminateds = [terminateds]
|
||||
truncateds = [truncateds]
|
||||
terminateds = list(terminateds)
|
||||
truncateds = list(truncateds)
|
||||
|
||||
for i in range(len(dones)):
|
||||
if dones[i]:
|
||||
for i in range(len(terminateds)):
|
||||
if terminateds[i] or truncateds[i]:
|
||||
episode_return = self.episode_returns[i]
|
||||
episode_length = self.episode_lengths[i]
|
||||
episode_info = {
|
||||
@@ -134,9 +144,14 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
self.episode_count += 1
|
||||
self.episode_returns[i] = 0
|
||||
self.episode_lengths[i] = 0
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
dones if self.is_vector_env else dones[0],
|
||||
infos,
|
||||
return step_api_compatibility(
|
||||
(
|
||||
observations,
|
||||
rewards,
|
||||
terminateds if self.is_vector_env else terminateds[0],
|
||||
truncateds if self.is_vector_env else truncateds[0],
|
||||
infos,
|
||||
),
|
||||
self.new_step_api,
|
||||
self.is_vector_env,
|
||||
)
|
||||
|
@@ -4,6 +4,7 @@ from typing import Callable, Optional
|
||||
|
||||
import gym
|
||||
from gym import logger
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
from gym.wrappers.monitoring import video_recorder
|
||||
|
||||
|
||||
@@ -32,7 +33,7 @@ class RecordVideo(gym.Wrapper):
|
||||
They should be functions returning a boolean that indicates whether a recording should be started at the
|
||||
current episode or step, respectively.
|
||||
If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed.
|
||||
By default, the recording will be stopped once a `done` signal has been emitted by the environment. However, you can
|
||||
By default, the recording will be stopped once a `terminated` or `truncated` signal has been emitted by the environment. However, you can
|
||||
also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for
|
||||
``video_length``.
|
||||
"""
|
||||
@@ -45,6 +46,7 @@ class RecordVideo(gym.Wrapper):
|
||||
step_trigger: Callable[[int], bool] = None,
|
||||
video_length: int = 0,
|
||||
name_prefix: str = "rl-video",
|
||||
new_step_api: bool = False,
|
||||
):
|
||||
"""Wrapper records videos of rollouts.
|
||||
|
||||
@@ -56,8 +58,9 @@ class RecordVideo(gym.Wrapper):
|
||||
video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
|
||||
Otherwise, snippets of the specified length are captured
|
||||
name_prefix (str): Will be prepended to the filename of the recordings
|
||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api)
|
||||
|
||||
if episode_trigger is None and step_trigger is None:
|
||||
episode_trigger = capped_cubic_video_schedule
|
||||
@@ -83,7 +86,8 @@ class RecordVideo(gym.Wrapper):
|
||||
self.video_length = video_length
|
||||
|
||||
self.recording = False
|
||||
self.done = False
|
||||
self.terminated = False
|
||||
self.truncated = False
|
||||
self.recorded_frames = 0
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
self.episode_id = 0
|
||||
@@ -91,7 +95,8 @@ class RecordVideo(gym.Wrapper):
|
||||
def reset(self, **kwargs):
|
||||
"""Reset the environment using kwargs and then starts recording if video enabled."""
|
||||
observations = super().reset(**kwargs)
|
||||
self.done = False
|
||||
self.terminated = False
|
||||
self.truncated = False
|
||||
if self.recording:
|
||||
assert self.video_recorder is not None
|
||||
self.video_recorder.frames = []
|
||||
@@ -132,18 +137,26 @@ class RecordVideo(gym.Wrapper):
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""
|
||||
observations, rewards, dones, infos = super().step(action)
|
||||
(
|
||||
observations,
|
||||
rewards,
|
||||
terminateds,
|
||||
truncateds,
|
||||
infos,
|
||||
) = step_api_compatibility(self.env.step(action), True, self.is_vector_env)
|
||||
|
||||
if not self.done:
|
||||
if not (self.terminated or self.truncated):
|
||||
# increment steps and episodes
|
||||
self.step_id += 1
|
||||
if not self.is_vector_env:
|
||||
if dones:
|
||||
if terminateds or truncateds:
|
||||
self.episode_id += 1
|
||||
self.done = True
|
||||
elif dones[0]:
|
||||
self.terminated = terminateds
|
||||
self.truncated = truncateds
|
||||
elif terminateds[0] or truncateds[0]:
|
||||
self.episode_id += 1
|
||||
self.done = True
|
||||
self.terminated = terminateds[0]
|
||||
self.truncated = truncateds[0]
|
||||
|
||||
if self.recording:
|
||||
assert self.video_recorder is not None
|
||||
@@ -154,15 +167,19 @@ class RecordVideo(gym.Wrapper):
|
||||
self.close_video_recorder()
|
||||
else:
|
||||
if not self.is_vector_env:
|
||||
if dones:
|
||||
if terminateds or truncateds:
|
||||
self.close_video_recorder()
|
||||
elif dones[0]:
|
||||
elif terminateds[0] or truncateds[0]:
|
||||
self.close_video_recorder()
|
||||
|
||||
elif self._video_enabled():
|
||||
self.start_video_recorder()
|
||||
|
||||
return observations, rewards, dones, infos
|
||||
return step_api_compatibility(
|
||||
(observations, rewards, terminateds, truncateds, infos),
|
||||
self.new_step_api,
|
||||
self.is_vector_env,
|
||||
)
|
||||
|
||||
def close_video_recorder(self):
|
||||
"""Closes the video recorder if currently recording."""
|
||||
|
@@ -45,7 +45,7 @@ class RescaleAction(gym.ActionWrapper):
|
||||
), f"expected Box action space, got {type(env.action_space)}"
|
||||
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
||||
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
self.min_action = (
|
||||
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
|
||||
)
|
||||
|
@@ -32,7 +32,7 @@ class ResizeObservation(gym.ObservationWrapper):
|
||||
env: The environment to apply the wrapper
|
||||
shape: The shape of the resized observations
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api=True)
|
||||
if isinstance(shape, int):
|
||||
shape = (shape, shape)
|
||||
assert all(x > 0 for x in shape), shape
|
||||
|
57
gym/wrappers/step_api_compatibility.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API."""
|
||||
import gym
|
||||
from gym.logger import deprecation
|
||||
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
|
||||
|
||||
|
||||
class StepAPICompatibility(gym.Wrapper):
|
||||
r"""A wrapper which can transform an environment from new step API to old and vice-versa.
|
||||
|
||||
Old step API refers to step() method returning (observation, reward, done, info)
|
||||
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
|
||||
(Refer to docs for details on the API change)
|
||||
|
||||
This wrapper is to be used to ease transition to new API and for backward compatibility.
|
||||
|
||||
Args:
|
||||
env (gym.Env): the env to wrap. Can be in old or new API
|
||||
new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default)
|
||||
|
||||
Examples:
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env # wrapper applied by default, set to old API
|
||||
<TimeLimit<OrderEnforcing<StepAPICompatibility<CartPoleEnv<CartPole-v1>>>>>
|
||||
>>> env = gym.make("CartPole-v1", new_step_api=True) # set to new API
|
||||
>>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, new_step_api=False):
|
||||
"""A wrapper which can transform an environment from new step API to old and vice-versa.
|
||||
|
||||
Args:
|
||||
env (gym.Env): the env to wrap. Can be in old or new API
|
||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env, new_step_api)
|
||||
self.new_step_api = new_step_api
|
||||
if not self.new_step_api:
|
||||
deprecation(
|
||||
"Initializing environment in old step API which returns one bool instead of two. "
|
||||
"It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. "
|
||||
)
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment, returning 5 or 4 items depending on `new_step_api`.
|
||||
|
||||
Args:
|
||||
action: action to step through the environment with
|
||||
|
||||
Returns:
|
||||
(observation, reward, terminated, truncated, info) or (observation, reward, done, info)
|
||||
"""
|
||||
step_returns = self.env.step(action)
|
||||
if self.new_step_api:
|
||||
return step_to_new_api(step_returns)
|
||||
else:
|
||||
return step_to_old_api(step_returns)
|
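A short hedged sketch of exposing an old-API environment through the new API (``CustomOldEnv`` and ``action`` are hypothetical):

from gym.wrappers import StepAPICompatibility

env = StepAPICompatibility(CustomOldEnv(), new_step_api=True)  # CustomOldEnv returns 4-tuples
obs, reward, terminated, truncated, info = env.step(action)    # converted to 5-tuples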
@@ -3,6 +3,7 @@ import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.spaces import Box
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
|
||||
|
||||
class TimeAwareObservation(gym.ObservationWrapper):
|
||||
@@ -21,18 +22,20 @@ class TimeAwareObservation(gym.ObservationWrapper):
|
||||
array([ 0.03881167, -0.16021058, 0.0220928 , 0.28875574, 1. ])
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
def __init__(self, env: gym.Env, new_step_api: bool = False):
|
||||
"""Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space.
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, new_step_api)
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert env.observation_space.dtype == np.float32
|
||||
low = np.append(self.observation_space.low, 0.0)
|
||||
high = np.append(self.observation_space.high, np.inf)
|
||||
self.observation_space = Box(low, high, dtype=np.float32)
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
|
||||
def observation(self, observation):
|
||||
"""Adds to the observation with the current time step.
|
||||
@@ -55,7 +58,9 @@ class TimeAwareObservation(gym.ObservationWrapper):
|
||||
The environment's step using the action.
|
||||
"""
|
||||
self.t += 1
|
||||
return super().step(action)
|
||||
return step_api_compatibility(
|
||||
super().step(action), self.new_step_api, self.is_vector_env
|
||||
)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Reset the environment setting the time to zero.
|
||||
|
@@ -2,16 +2,20 @@
from typing import Optional

import gym
from gym.utils.step_api_compatibility import step_api_compatibility


class TimeLimit(gym.Wrapper):
"""This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded.
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.

Oftentimes, it is **very** important to distinguish `done` signals that were produced by the
:class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations).
This can be done by looking at the ``info`` that is returned when `done`-signal was issued.
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.

(deprecated)
This information is passed through ``info`` that is returned when `done`-signal was issued.
The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``.
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``. This will be removed in favour
of only issuing a `truncated` signal in future versions.

Example:
>>> from gym.envs.classic_control import CartPoleEnv
@@ -20,14 +24,20 @@ class TimeLimit(gym.Wrapper):
>>> env = TimeLimit(env, max_episode_steps=1000)
"""

def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
def __init__(
self,
env: gym.Env,
max_episode_steps: Optional[int] = None,
new_step_api: bool = False,
):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.

Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env)
super().__init__(env, new_step_api)
if max_episode_steps is None and self.env.spec is not None:
max_episode_steps = env.spec.max_episode_steps
if self.env.spec is not None:
@@ -46,15 +56,19 @@ class TimeLimit(gym.Wrapper):
when truncated (the number of steps elapsed >= max episode steps) or
"TimeLimit.truncated"=False if the environment terminated
"""
observation, reward, done, info = self.env.step(action)
observation, reward, terminated, truncated, info = step_api_compatibility(
self.env.step(action),
True,
)
self._elapsed_steps += 1

if self._elapsed_steps >= self._max_episode_steps:
# TimeLimit.truncated key may have been already set by the environment
# do not overwrite it
episode_truncated = not done or info.get("TimeLimit.truncated", False)
info["TimeLimit.truncated"] = episode_truncated
done = True
return observation, reward, done, info
truncated = True

return step_api_compatibility(
(observation, reward, terminated, truncated, info),
self.new_step_api,
)

def reset(self, **kwargs):
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
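A short usage sketch of the reworked wrapper (illustrative, not part of this diff; it assumes CartPole-v1 and sits on top of whatever TimeLimit gym.make already applied):

import gym
from gym.wrappers import TimeLimit

env = TimeLimit(
    gym.make("CartPole-v1", new_step_api=True), max_episode_steps=10, new_step_api=True
)
env.reset()

terminated = truncated = False
steps = 0
while not (terminated or truncated):
    _, _, terminated, truncated, _ = env.step(env.action_space.sample())
    steps += 1

# If the pole has not fallen within 10 steps, the episode ends with
# truncated=True and terminated=False instead of a single ambiguous `done`.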
@@ -27,7 +27,7 @@ class TransformObservation(gym.ObservationWrapper):
env: The environment to apply the wrapper
f: A function that transforms the observation
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
assert callable(f)
self.f = f


@@ -16,7 +16,7 @@ class TransformReward(RewardWrapper):
>>> env = gym.make('CartPole-v1')
>>> env = TransformReward(env, lambda r: 0.01*r)
>>> env.reset()
>>> observation, reward, done, info = env.step(env.action_space.sample())
>>> observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
>>> reward
0.01
"""
@@ -28,7 +28,7 @@ class TransformReward(RewardWrapper):
env: The environment to apply the wrapper
f: A function that transforms the reward
"""
super().__init__(env)
super().__init__(env, new_step_api=True)
assert callable(f)
self.f = f
@@ -3,6 +3,7 @@
from typing import List

import gym
from gym.utils.step_api_compatibility import step_api_compatibility


class VectorListInfo(gym.Wrapper):
@@ -29,23 +30,30 @@ class VectorListInfo(gym.Wrapper):

"""

def __init__(self, env):
def __init__(self, env, new_step_api=False):
"""This wrapper will convert the info into the list format.

Args:
env (Env): The environment to apply the wrapper
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
assert getattr(
env, "is_vector_env", False
), "This wrapper can only be used in vectorized environments."
super().__init__(env)
super().__init__(env, new_step_api)

def step(self, action):
"""Steps through the environment, convert dict info to list."""
observation, reward, done, infos = self.env.step(action)
observation, reward, terminated, truncated, infos = step_api_compatibility(
self.env.step(action), True, True
)
list_info = self._convert_info_to_list(infos)

return observation, reward, done, list_info
return step_api_compatibility(
(observation, reward, terminated, truncated, list_info),
self.new_step_api,
True,
)

def reset(self, **kwargs):
"""Resets the environment using kwargs."""
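As an illustration of the conversion this wrapper performs, here is a hypothetical helper in the spirit of _convert_info_to_list (whose actual implementation is not part of this hunk); the values are made up, and the "_key" masks follow the dict-style vector info format used elsewhere in this commit:

import numpy as np

def info_dict_to_list(infos, num_envs):
    # Hypothetical sketch: regroup per-key arrays into one dict per sub-environment,
    # keeping a key for env i only where the boolean "_key" mask marks it as valid.
    out = [{} for _ in range(num_envs)]
    for key, values in infos.items():
        if key.startswith("_"):
            continue
        mask = infos.get(f"_{key}", np.ones(num_envs, dtype=bool))
        for i in range(num_envs):
            if mask[i]:
                out[i][key] = values[i]
    return out

dict_info = {
    "final_observation": np.array([np.array([0.1, 0.2]), None, None], dtype=object),
    "_final_observation": np.array([True, False, False]),
}
print(info_dict_to_list(dict_info, num_envs=3))
# -> [{'final_observation': array([0.1, 0.2])}, {}, {}]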
@@ -34,5 +34,8 @@ reportPrivateUsage = "warning"
reportUntypedFunctionDecorator = "none"
reportMissingTypeStubs = false
reportUnboundVariable = "warning"
reportGeneralTypeIssues = "none"
reportInvalidTypeVarUse = "none"
reportGeneralTypeIssues ="none"
reportInvalidTypeVarUse = "none"

[tool.pytest.ini_options]
filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # TODO: to be removed when old step API is removed
@@ -112,7 +112,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
zip(env.action_space.bounded_above, env.action_space.bounded_below)
):
if is_upper_bound:
obs, _, _, _ = env.step(upper_bounds)
obs, _, _, _, _ = env.step(
upper_bounds
) # `env` is unwrapped, and in new step API
oob_action = upper_bounds.copy()
oob_action[i] += np.cast[dtype](OOB_VALUE)

@@ -122,7 +124,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
assert np.alltrue(obs == oob_obs)

if is_lower_bound:
obs, _, _, _ = env.step(lower_bounds)
obs, _, _, _, _ = env.step(
lower_bounds
) # `env` is unwrapped, and in new step API
oob_action = lower_bounds.copy()
oob_action[i] -= np.cast[dtype](OOB_VALUE)
@@ -144,7 +144,7 @@ def test_taxi_encode_decode():
assert (
env.encode(*env.decode(state)) == state
), f"state={state}, encode(decode(state))={env.encode(*env.decode(state))}"
state, _, _, _ = env.step(env.action_space.sample())
state, _, _, _, _ = env.step(env.action_space.sample())


@pytest.mark.parametrize(
@@ -172,7 +172,9 @@ def test_make_render_mode():
assert env.render_mode == valid_render_modes[0]
env.close()

assert len(warnings) == 0
for warning in warnings.list:
if not re.compile(".*step API.*").match(warning.message.args[0]):
raise gym.error.Error(f"Unexpected warning: {warning.message}")

# Make sure that native rendering is used when possible
env = gym.make("CartPole-v1", render_mode="human", disable_env_checker=True)
91
tests/utils/test_terminated_truncated.py
Normal file
@@ -0,0 +1,91 @@
import pytest

import gym
from gym.spaces import Discrete
from gym.vector import AsyncVectorEnv, SyncVectorEnv
from gym.wrappers import TimeLimit


# An environment where termination happens after 20 steps
class DummyEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)
self.terminal_timestep = 20

self.timestep = 0

def step(self, action):
self.timestep += 1
terminated = True if self.timestep >= self.terminal_timestep else False
truncated = False

return 0, 0, terminated, truncated, {}

def reset(self):
self.timestep = 0
return 0


@pytest.mark.parametrize("time_limit", [10, 20, 30])
def test_terminated_truncated(time_limit):
test_env = TimeLimit(DummyEnv(), time_limit, new_step_api=True)

terminated = False
truncated = False
test_env.reset()
while not (terminated or truncated):
_, _, terminated, truncated, _ = test_env.step(0)

if test_env.terminal_timestep < time_limit:
assert terminated
assert not truncated
elif test_env.terminal_timestep == time_limit:
assert (
terminated
), "`terminated` should be True even when termination and truncation happen at the same step"
assert (
truncated
), "`truncated` should be True even when termination and truncation occur at same step "
else:
assert not terminated
assert truncated


def test_terminated_truncated_vector():
env0 = TimeLimit(DummyEnv(), 10, new_step_api=True)
env1 = TimeLimit(DummyEnv(), 20, new_step_api=True)
env2 = TimeLimit(DummyEnv(), 30, new_step_api=True)

async_env = AsyncVectorEnv(
[lambda: env0, lambda: env1, lambda: env2], new_step_api=True
)
async_env.reset()
terminateds = [False, False, False]
truncateds = [False, False, False]
counter = 0
while not all([x or y for x, y in zip(terminateds, truncateds)]):
counter += 1
_, _, terminateds, truncateds, _ = async_env.step(
async_env.action_space.sample()
)
print(counter)
assert counter == 20
assert all(terminateds == [False, True, True])
assert all(truncateds == [True, True, False])

sync_env = SyncVectorEnv(
[lambda: env0, lambda: env1, lambda: env2], new_step_api=True
)
sync_env.reset()
terminateds = [False, False, False]
truncateds = [False, False, False]
counter = 0
while not all([x or y for x, y in zip(terminateds, truncateds)]):
counter += 1
_, _, terminateds, truncateds, _ = sync_env.step(
async_env.action_space.sample()
)
assert counter == 20
assert all(terminateds == [False, True, True])
assert all(truncateds == [True, True, False])
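The tests above pin down the behaviour that matters for agents: the wrapper raises `truncated` at the limit, and both flags can be True on the same step. A minimal sketch (not part of this diff) of a rollout loop consuming the two flags; the bootstrapping comments state the usual motivation for the split rather than anything this commit enforces:

import gym

env = gym.make("CartPole-v1", new_step_api=True)
obs = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    action = env.action_space.sample()  # stand-in for a policy
    next_obs, reward, terminated, truncated, info = env.step(action)
    # terminated: a terminal state of the MDP was reached, so a value-based agent
    # should not bootstrap from next_obs.
    # truncated: the episode was cut short (e.g. by TimeLimit), so bootstrapping
    # from next_obs is still legitimate.
    obs = next_obs
env.close()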
88
tests/vector/test_step_compatibility_vector.py
Normal file
@@ -0,0 +1,88 @@
import numpy as np
import pytest

import gym
from gym.spaces import Discrete
from gym.vector import AsyncVectorEnv, SyncVectorEnv


class OldStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)

def reset(self):
return 0

def step(self, action):
obs = self.observation_space.sample()
rew = 0
done = False
info = {}
return obs, rew, done, info


class NewStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)

def reset(self):
return 0

def step(self, action):
obs = self.observation_space.sample()
rew = 0
terminated = False
truncated = False
info = {}
return obs, rew, terminated, truncated, info


@pytest.mark.parametrize("VecEnv", [AsyncVectorEnv, SyncVectorEnv])
def test_vector_step_compatibility_new_env(VecEnv):

envs = [
OldStepEnv(),
NewStepEnv(),
]

vec_env = VecEnv([lambda: env for env in envs])
vec_env.reset()
step_returns = vec_env.step([0, 0])
assert len(step_returns) == 4
_, _, dones, _ = step_returns
assert dones.dtype == np.bool_
vec_env.close()

vec_env = VecEnv([lambda: env for env in envs], new_step_api=True)
vec_env.reset()
step_returns = vec_env.step([0, 0])
assert len(step_returns) == 5
_, _, terminateds, truncateds, _ = step_returns
assert terminateds.dtype == np.bool_
assert truncateds.dtype == np.bool_
vec_env.close()


@pytest.mark.parametrize("async_bool", [True, False])
def test_vector_step_compatibility_existing(async_bool):

env = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=async_bool)
env.reset()
step_returns = env.step(env.action_space.sample())
assert len(step_returns) == 4
_, _, dones, _ = step_returns
assert dones.dtype == np.bool_
env.close()

env = gym.vector.make(
"CartPole-v1", num_envs=3, asynchronous=async_bool, new_step_api=True
)
env.reset()
step_returns = env.step(env.action_space.sample())
assert len(step_returns) == 5
_, _, terminateds, truncateds, _ = step_returns
assert terminateds.dtype == np.bool_
assert truncateds.dtype == np.bool_
env.close()
@@ -36,10 +36,10 @@ def test_vector_env_equal(shared_memory):
# fmt: on

if any(sync_dones):
assert "terminal_observation" in async_infos
assert "_terminal_observation" in async_infos
assert "terminal_observation" in sync_infos
assert "_terminal_observation" in sync_infos
assert "final_observation" in async_infos
assert "_final_observation" in async_infos
assert "final_observation" in sync_infos
assert "_final_observation" in sync_infos

assert np.all(async_observations == sync_observations)
assert np.all(async_rewards == sync_rewards)
@@ -22,18 +22,18 @@ def test_vector_env_info(asynchronous):
action = env.action_space.sample()
_, _, dones, infos = env.step(action)
if any(dones):
assert len(infos["terminal_observation"]) == NUM_ENVS
assert len(infos["_terminal_observation"]) == NUM_ENVS
assert len(infos["final_observation"]) == NUM_ENVS
assert len(infos["_final_observation"]) == NUM_ENVS

assert isinstance(infos["terminal_observation"], np.ndarray)
assert isinstance(infos["_terminal_observation"], np.ndarray)
assert isinstance(infos["final_observation"], np.ndarray)
assert isinstance(infos["_final_observation"], np.ndarray)

for i, done in enumerate(dones):
if done:
assert infos["_terminal_observation"][i]
assert infos["_final_observation"][i]
else:
assert not infos["_terminal_observation"][i]
assert infos["terminal_observation"][i] is None
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None


@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
@@ -49,8 +49,8 @@ def test_vector_env_info_concurrent_termination(concurrent_ends):
for i, done in enumerate(dones):
if i < concurrent_ends:
assert done
assert infos["_terminal_observation"][i]
assert infos["_final_observation"][i]
else:
assert not infos["_terminal_observation"][i]
assert infos["terminal_observation"][i] is None
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None
return
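The key renames above (terminal_observation -> final_observation, plus the _final_observation mask) are how a vector environment exposes the last observation of a sub-environment that just finished. A small sketch of reading them, assuming the dict-style info format these tests use:

import gym

envs = gym.vector.make("CartPole-v1", num_envs=3, new_step_api=True)
envs.reset()
for _ in range(200):
    obs, rewards, terminateds, truncateds, infos = envs.step(envs.action_space.sample())
    if "final_observation" in infos:
        # The underscored key is a boolean mask marking which entries are valid.
        for i, is_final in enumerate(infos["_final_observation"]):
            if is_final:
                final_obs = infos["final_observation"][i]  # last obs before autoreset
envs.close()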
@@ -136,8 +136,8 @@ def test_autoreset_wrapper_autoreset():
assert reward == 1
assert info == {
"count": 0,
"terminal_observation": np.array([3]),
"terminal_info": {"count": 3},
"final_observation": np.array([3]),
"final_info": {"count": 3},
}

obs, reward, done, info = env.step(action)
77
tests/wrappers/test_step_compatibility.py
Normal file
@@ -0,0 +1,77 @@
import pytest

import gym
from gym.spaces import Discrete
from gym.wrappers import StepAPICompatibility


class OldStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)

def step(self, action):
obs = self.observation_space.sample()
rew = 0
done = False
info = {}
return obs, rew, done, info


class NewStepEnv(gym.Env):
def __init__(self):
self.action_space = Discrete(2)
self.observation_space = Discrete(2)

def step(self, action):
obs = self.observation_space.sample()
rew = 0
terminated = False
truncated = False
info = {}
return obs, rew, terminated, truncated, info


@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
def test_step_compatibility_to_new_api(env):
env = StepAPICompatibility(env(), True)
step_returns = env.step(0)
_, _, terminated, truncated, _ = step_returns
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)


@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
@pytest.mark.parametrize("new_step_api", [None, False])
def test_step_compatibility_to_old_api(env, new_step_api):
if new_step_api is None:
env = StepAPICompatibility(env()) # default behavior is to retain old API
else:
env = StepAPICompatibility(env(), new_step_api)
step_returns = env.step(0)
assert len(step_returns) == 4
_, _, done, _ = step_returns
assert isinstance(done, bool)


@pytest.mark.parametrize("new_step_api", [None, True, False])
def test_step_compatibility_in_make(new_step_api):
if new_step_api is None:
with pytest.warns(
DeprecationWarning, match="Initializing environment in old step API"
):
env = gym.make("CartPole-v1")
else:
env = gym.make("CartPole-v1", new_step_api=new_step_api)

env.reset()
step_returns = env.step(0)
if new_step_api:
assert len(step_returns) == 5
_, _, terminated, truncated, _ = step_returns
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)
else:
assert len(step_returns) == 4
_, _, done, _ = step_returns
assert isinstance(done, bool)
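These tests describe the StepAPICompatibility wrapper entirely from the outside; a short usage sketch mirroring them, nothing beyond what the tests themselves exercise:

import gym
from gym.wrappers import StepAPICompatibility

# Expose an environment through the new API: step() returns five values.
env = StepAPICompatibility(gym.make("CartPole-v1"), True)
env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

# With the default flag the old API is kept, so legacy code still receives 4-tuples.
legacy_env = StepAPICompatibility(gym.make("CartPole-v1"))
legacy_env.reset()
obs, reward, done, info = legacy_env.step(legacy_env.action_space.sample())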
@@ -32,9 +32,9 @@ def test_info_to_list():
_, _, dones, list_info = wrapped_env.step(action)
for i, done in enumerate(dones):
if done:
assert "terminal_observation" in list_info[i]
assert "final_observation" in list_info[i]
else:
assert "terminal_observation" not in list_info[i]
assert "final_observation" not in list_info[i]


def test_info_to_list_statistics():