Update Docs with New Step API (#23)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
Arjun KG
2022-09-27 21:50:22 +05:30
committed by GitHub
parent 95da6c5714
commit 48b966233c
8 changed files with 125 additions and 42 deletions

View File

@@ -72,7 +72,7 @@ title: Vector
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> envs.reset()
>>> actions = np.array([1, 0, 1])
>>> observations, rewards, terminated, truncated, infos = envs.step(actions)
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> observations
array([[ 0.00122802, 0.16228443, 0.02521779, -0.23700266],
@@ -81,7 +81,9 @@ array([[ 0.00122802, 0.16228443, 0.02521779, -0.23700266],
dtype=float32)
>>> rewards
array([1., 1., 1.])
>>> terminated
>>> termination
array([False, False, False])
>>> truncation
array([False, False, False])
>>> infos
{}

View File

@@ -132,28 +132,28 @@ class ClipReward(gym.RewardWrapper):
Some users may want a wrapper that automatically resets the wrapped environment when it reaches the done state. An advantage of this wrapper is that it will never produce undefined behavior, as standard gymnasium environments do when stepping beyond the done state.
When calling step causes `self.env.step()` to return `done=True`,
When calling step causes `self.env.step()` to return `(terminated or truncated)=True`,
`self.env.reset()` is called,
and the return format of `self.step()` is as follows:
```python
new_obs, terminal_reward, terminated, truncated, info
new_obs, final_reward, final_terminated, final_truncated, info
```
`new_obs` is the first observation after calling `self.env.reset()`,
`terminal_reward` is the reward after calling `self.env.step()`,
`final_reward` is the reward after calling `self.env.step()`,
prior to calling `self.env.reset()`
`terminated or truncated` is always `True`
The expression `(final_terminated or final_truncated)` is always `True`
`info` is a dict containing all the keys from the info dict returned by
the call to `self.env.reset()`, with additional keys `terminal_observation`
the call to `self.env.reset()`, with additional keys `final_observation`
containing the observation returned by the last call to `self.env.step()`
and `terminal_info` containing the info dict returned by the last call
and `final_info` containing the info dict returned by the last call
to `self.env.step()`.
If `done` is not true when `self.env.step()` is called, `self.step()` returns
If `(terminated or truncated)` is not true when `self.env.step()` is called, `self.step()` returns
```python
obs, reward, terminated, truncated, info
@@ -180,7 +180,7 @@ that the when `self.env.step()` returns `done`, a
new observation from after calling `self.env.reset()` is returned
by `self.step()` alongside the terminal reward and done state from the
previous episode . If you need the terminal state from the previous
episode, you need to retrieve it via the the `terminal_observation` key
episode, you need to retrieve it via the `final_observation` key
in the info dict. Make sure you know what you're doing if you
use this wrapper!
```
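As a usage sketch, here is how such an auto-resetting wrapper might be applied, assuming the `AutoResetWrapper` available in `gymnasium.wrappers` (this block is illustrative and not part of the original page):
```python
import gymnasium as gym
from gymnasium.wrappers import AutoResetWrapper

# The wrapped environment resets itself whenever an episode ends.
env = AutoResetWrapper(gym.make("CartPole-v1"))
obs, info = env.reset(seed=0)
for _ in range(500):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        # `obs` already belongs to the new episode; the last observation of the
        # finished episode is available under info["final_observation"].
        last_obs = info["final_observation"]
env.close()
```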

View File

@@ -37,10 +37,8 @@ to a specific point in space. If it succeeds in doing this (or makes some progre
alongside the observation for this timestep. The reward may also be negative or 0, if the agent did not yet succeed (or did not make any progress).
The agent will then be trained to maximize the reward it accumulates over many timesteps.
After some timesteps, the environment may enter a terminal state. For instance, the robot may have crashed! In that case,
we want to reset the environment to a new initial state. The environment issues a done signal to the agent if it enters such a terminal state.
Not all done signals must be triggered by a "catastrophic failure": Sometimes we also want to issue a done signal after
a fixed number of timesteps, or if the agent has succeeded in completing some task in the environment.
After some timesteps, the environment may enter a terminal state. For instance, the robot may have crashed, or the agent may have succeeded in completing a task. In that case, we want to reset the environment to a new initial state. The environment issues a terminated signal to the agent if it enters such a terminal state. Sometimes we also want to end the episode after a fixed number of timesteps; in this case, the environment issues a truncated signal.
This is a new change in the API (v0.26 onwards). Earlier, a single done signal was issued for an episode ending by any means. This has now been changed in favour of issuing two signals: terminated and truncated.
Let's see what the agent-environment loop looks like in Gymnasium.
This example will run an instance of the `LunarLander-v2` environment for 1000 timesteps. Since we pass `render_mode="human"`, you should see a window pop up rendering the environment.
@@ -74,6 +72,33 @@ the format of valid observations is specified by `env.observation_space`.
In the example above we sampled random actions via `env.action_space.sample()`. Note that we need to seed the action space separately from the
environment to ensure reproducible samples.
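Concretely, seeding both looks like this (a small sketch, not taken from the original page):
```python
import gymnasium as gym

env = gym.make("LunarLander-v2")
# The action space keeps its own random number generator, so seed it in
# addition to seeding the environment through reset().
env.action_space.seed(42)
observation, info = env.reset(seed=42)
action = env.action_space.sample()  # now reproducible across runs
```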
### Change in env.step API
Previously, the step method returned only one boolean, `done`. This is being deprecated in favour of returning two booleans, `terminated` and `truncated` (v0.26 onwards).
The `terminated` signal is set to `True` when the core environment terminates inherently, e.g. because of task completion or failure, i.e. a condition defined within the MDP.
The `truncated` signal is set to `True` when the episode ends specifically because of a time limit or another condition not inherent to the environment (not defined in the MDP).
Both `terminated=True` and `truncated=True` can occur when termination and truncation happen at the same step.
This is explained in detail in the `Handling Time Limits` section.
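For instance, a minimal rollout loop under the new API (a sketch based on the `LunarLander-v2` example above) looks like this:
```python
import gymnasium as gym

env = gym.make("LunarLander-v2")
observation, info = env.reset(seed=42)
for _ in range(1000):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    # An episode can end for either reason; reset in both cases.
    if terminated or truncated:
        observation, info = env.reset()
env.close()
```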
#### Backward compatibility
Gymnasium will retain support for the old API through compatibility wrappers.
Users can toggle the old API through `make` by setting `apply_api_compatibility=True`.
```python
env = gym.make("CartPole-v1", apply_api_compatibility=True)
```
This can also be done explicitly through a wrapper:
```python
from gymnasium.wrappers import StepAPICompatibility
env = StepAPICompatibility(CustomEnv(), output_truncation_bool=False)
```
For more details see the wrappers section.
## Checking API-Conformity
If you have implemented a custom environment and would like to perform a sanity check to make sure that it conforms to
the API, you can run:
@@ -169,7 +194,7 @@ reward based on data in `info`). Such wrappers
can be implemented by inheriting from `Wrapper`.
Gymnasium already provides many commonly used wrappers for you. Some examples:
- `TimeLimit`: Issue a done signal if a maximum number of timesteps has been exceeded (or the base environment has issued a done signal).
- `TimeLimit`: Issue a truncated signal if a maximum number of timesteps has been exceeded (or the base environment has issued a truncated signal).
- `ClipAction`: Clip the action such that it lies in the action space (of type `Box`).
- `RescaleAction`: Rescale actions to lie in a specified interval.
- `TimeAwareObservation`: Add information about the index of timestep to observation. In some cases helpful to ensure that transitions are Markov.
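As a small sketch of composing the wrappers listed above (illustrative only, using `Pendulum-v1` as an assumed continuous-control example):
```python
import gymnasium as gym
from gymnasium.wrappers import ClipAction, TimeLimit

# Clip actions to the Box action space and issue `truncated=True` after 100 steps.
env = ClipAction(TimeLimit(gym.make("Pendulum-v1").unwrapped, max_episode_steps=100))
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```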
@@ -212,7 +237,7 @@ where we obtain the corresponding key ID constants from pygame. If the `key_to_a
Furthermore, if you wish to plot real-time statistics as you play, you can use `gymnasium.utils.play.PlayPlot`. Here's some sample code for plotting the reward for the last 5 seconds of gameplay:
```python
def callback(obs_t, obs_tp1, action, rew, done, info):
def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
return [rew,]
plotter = PlayPlot(callback, 30 * 5, ["reward"])
env = gymnasium.make("Pong-v0")
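# Assumed continuation (a sketch, not part of the original snippet): `play`
# lives in gymnasium.utils.play and accepts the plotter's callback, so the
# reward curve updates while you control the agent.
from gymnasium.utils.play import play
play(env, callback=plotter.callback)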

View File

@@ -0,0 +1,53 @@
# Handling Time Limits
When using Gym environments with reinforcement learning code, a common problem is that time limits are handled incorrectly. The `done` signal received (in previous versions of gym < 0.26) from `env.step` indicated whether an episode had ended. However, this signal did not distinguish between whether the episode ended due to `termination` or `truncation`.
### Termination
Termination refers to the episode ending after reaching a terminal state that is defined as part of the environment definition. Examples are task success, task failure, the robot falling down, etc. Notably, this also includes episodes ending in finite-horizon environments due to a time limit inherent to the environment. Note that to preserve the Markov property, a representation of the remaining time must be present in the agent's observation in finite-horizon environments. [(Reference)](https://arxiv.org/abs/1712.00378)
### Truncation
Truncation refers to the episode ending after an externally defined condition (one that is outside the scope of the Markov Decision Process). This could be a time limit, the robot going out of bounds, etc.
An infinite-horizon environment is an obvious example where this is needed. We cannot wait forever for the episode to complete, so we set a practical time limit after which we forcibly halt the episode. The last state in this case is not a terminal state, since it has a non-zero transition probability of moving to another state as per the Markov Decision Process that defines the RL problem. This is also different from time limits in finite-horizon environments, as the agent in this case has no idea about the time limit.
### Importance in learning code
Bootstrapping (using one or more estimated values of a variable to update estimates of the same variable) is a key aspect of Reinforcement Learning. A value function will tell you how much discounted reward you will get from a particular state if you follow a given policy. When an episode stops at any given point, by looking at the value of the final state, the agent is able to estimate how much discounted reward could have been obtained if the episode had continued. This is an example of handling truncation.
More formally, a common example of bootstrapping in RL is updating the estimate of the Q-value function,
```math
Q_{target}(o_t, a_t) = r_t + \gamma \cdot \max_a Q(o_{t+1}, a)
```
In classical RL, the new `Q` estimate is a weighted average of the previous `Q` estimate and `Q_target`, while in Deep Q-Learning, the error between `Q_target` and the previous `Q` estimate is minimized.
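Spelled out, this "weighted average" update with learning rate `alpha` (the standard tabular Q-learning rule, added here for concreteness) is,
```math
Q_{new}(o_t, a_t) = (1 - \alpha) \cdot Q(o_t, a_t) + \alpha \cdot Q_{target}(o_t, a_t)
```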
However, at the terminal state, bootstrapping is not done,
```math
Q_{target}(o_t, a_t) = r_t
```
This is where the distinction between termination and truncation becomes important. When an episode ends due to termination, we don't bootstrap; when it ends due to truncation, we do.
When using gym environments, the `done` signal (the default for < v0.26) is frequently used to determine whether to bootstrap. However, this is incorrect, since it does not differentiate between termination and truncation.
A simple example for value functions is shown below. This is an illustrative example and not part of any specific algorithm.
```python
# INCORRECT
vf_target = rew + gamma * (1 - done) * vf_next_state
```
This is incorrect when the episode ends due to truncation, where bootstrapping needs to happen but does not.
### Solution
From v0.26 onwards, gym's `env.step` API returns both termination and truncation information explicitly. In previous versions, truncation information was supplied through the info key `TimeLimit.truncated`. The correct way to handle terminations and truncations now is,
```python
# terminated = done and 'TimeLimit.truncated' not in info # This was needed in previous versions.
vf_target = rew + gamma * (1 - terminated) * vf_next_state
```
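As an illustrative sketch (not part of any specific algorithm, and using a placeholder `value_fn`), a rollout loop that bootstraps on truncation but not on termination could look like:
```python
import gymnasium as gym

env = gym.make("CartPole-v1")
gamma = 0.99

def value_fn(observation):
    # Placeholder state-value estimate; a real agent would use a learned function.
    return 0.0

obs, info = env.reset(seed=0)
for _ in range(500):
    action = env.action_space.sample()
    next_obs, rew, terminated, truncated, info = env.step(action)
    # Bootstrap unless the episode truly terminated; truncation still bootstraps.
    vf_target = rew + gamma * (1 - terminated) * value_fn(next_obs)
    if terminated or truncated:
        next_obs, info = env.reset()
    obs = next_obs
env.close()
```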

View File

@@ -23,7 +23,7 @@ The following example runs 3 copies of the ``CartPole-v1`` environment in parall
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> envs.reset()
>>> actions = np.array([1, 0, 1])
>>> observations, rewards, terminated, truncated, infos = envs.step(actions)
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> observations
array([[ 0.00122802, 0.16228443, 0.02521779, -0.23700266],
@@ -32,7 +32,9 @@ array([[ 0.00122802, 0.16228443, 0.02521779, -0.23700266],
dtype=float32)
>>> rewards
array([1., 1., 1.])
>>> terminated
>>> termination
array([False, False, False])
>>> truncation
array([False, False, False])
>>> infos
{}
@@ -92,7 +94,7 @@ While standard Gymnasium environments take a single action and return a single o
dtype=float32), {})
>>> actions = np.array([1, 0, 1])
>>> observations, rewards, terminated, truncated, infos = envs.step(actions)
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> observations
array([[ 0.00187507, 0.18986781, -0.03168437, -0.301252 ],
@@ -101,7 +103,9 @@ array([[ 0.00187507, 0.18986781, -0.03168437, -0.301252 ],
dtype=float32)
>>> rewards
array([1., 1., 1.])
>>> terminated
>>> termination
array([False, False, False])
>>> truncation
array([False, False, False])
>>> infos
{}
@@ -140,7 +144,7 @@ Dict(fire:MultiDiscrete([2 2 2]), jump:MultiDiscrete([2 2 2]), acceleration:Box(
... "jump": np.array([0, 1, 0]),
... "acceleration": np.random.uniform(-1., 1., size=(3, 2))
... }
>>> observations, rewards, terminated, truncated, infos = envs.step(actions)
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> observations
{"position": array([[-0.5337036 , 0.7439302 , 0.41748118],
[ 0.9373266 , -0.5780453 , 0.8987405 ],
@@ -150,19 +154,18 @@ Dict(fire:MultiDiscrete([2 2 2]), jump:MultiDiscrete([2 2 2]), acceleration:Box(
[ 0.26341468, 0.72282314]], dtype=float32)}
```
The environment copies inside a vectorized environment automatically call `gymnasium.Env.reset` at the end of an episode. In the following example, the episode of the 3rd copy ends after 2 steps (the agent fell in a hole), and the paralle environment gets reset (observation ``0``).
The environment copies inside a vectorized environment automatically call `gymnasium.Env.reset` at the end of an episode. In the following example, the episode of the 3rd copy ends after 2 steps (the agent fell in a hole), and the parallel environment gets reset (observation ``0``).
```python
>>> envs = gym.vector.make("FrozenLake-v1", num_envs=3, is_slippery=False)
>>> envs.reset()
(array([0, 0, 0]), {'prob': array([1, 1, 1]), '_prob': array([ True, True, True])})
>>> observations, rewards, terminated, truncated, infos = envs.step(np.array([1, 2, 2]))
>>> observations, rewards, terminated, truncated, infos = envs.step(np.array([1, 2, 1]))
>>> terminated
array([False, False, True])
>>> observations, rewards, termination, truncation, infos = envs.step(np.array([1, 2, 2]))
>>> observations, rewards, termination, truncation, infos = envs.step(np.array([1, 2, 1]))
>>> observations
array([8, 2, 0])
>>> termination
array([False, False, True])
```
Vectorized environments will return `infos` in the form of a dictionary where each value is an array of length `num_envs` and the _i-th_ value of the array represents the info of the _i-th_ environment.
@@ -175,16 +178,15 @@ If the _dtype_ of the returned info is whether `int`, `float`, `bool` or any _dt
>>> observations, infos = envs.reset()
>>> actions = np.array([1, 0, 1])
>>> observations, rewards, terminated, truncated, infos = envs.step(actions)
>>> dones = np.logical_or(terminated, truncated)
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> while not any(dones):
... observations, rewards, terminated, truncated, infos = envs.step(actions)
>>> while not any(np.logical_or(termination, truncation)):
... observations, rewards, termination, truncation, infos = envs.step(actions)
>>> print(dones)
>>> termination
[False, True, False]
>>> print(infos)
>>> infos
{'final_observation': array([None,
array([-0.11350546, -1.8090094 , 0.23710881, 2.8017728 ], dtype=float32),
None], dtype=object), '_final_observation': array([False, True, False])}
@@ -238,7 +240,7 @@ This is convenient, for example, if you instantiate a policy. In the following e
... )
>>> observations, infos = envs.reset()
>>> actions = policy(weights, observations).argmax(axis=1)
>>> observations, rewards, terminated, truncated, infos = envs.step(actions)
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
```
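Putting the pieces above together, here is a minimal sketch of a vectorized rollout that recovers terminal observations through the `final_observation` and `_final_observation` info keys shown earlier:
```python
import numpy as np
import gymnasium as gym

envs = gym.vector.make("CartPole-v1", num_envs=3)
observations, infos = envs.reset(seed=42)
for _ in range(200):
    actions = envs.action_space.sample()
    observations, rewards, termination, truncation, infos = envs.step(actions)
    dones = np.logical_or(termination, truncation)
    if dones.any():
        # `_final_observation` masks which entries of `final_observation` are valid;
        # the finished sub-environments have already been reset automatically.
        for i in np.flatnonzero(infos["_final_observation"]):
            terminal_obs = infos["final_observation"][i]
envs.close()
```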
## Intermediate Usage
@@ -279,7 +281,7 @@ Because sometimes things may not go as planned, the exceptions raised in any giv
>>> envs = gymnasium.vector.AsyncVectorEnv([lambda: ErrorEnv()] * 3)
>>> observations, infos = envs.reset()
>>> observations, rewards, terminated, truncated, infos = envs.step(np.array([0, 0, 1]))
>>> observations, rewards, termination, truncation, infos = envs.step(np.array([0, 0, 1]))
ERROR: Received the following error from Worker-2: ValueError: An error occurred.
ERROR: Shutting down Worker-2.
ERROR: Raising the last exception back to the main process.
@@ -320,7 +322,7 @@ In the following example, we create a new environment `SMILESEnv`, whose observa
... shared_memory=False
... )
>>> envs.reset()
>>> observations, rewards, terminated, truncated, infos = envs.step(np.array([2, 5, 4]))
>>> observations, rewards, termination, truncation, infos = envs.step(np.array([2, 5, 4]))
>>> observations
('[(', '[O', '[C')
```

View File

@@ -64,6 +64,7 @@ environments/third_party_environments/index
content/environment_creation
content/vectorising
content/handling_timelimits
```
```{toctree}

View File

@@ -54,8 +54,8 @@ for env_spec in tqdm(gymnasium.envs.registry.values()):
frames = []
while True:
state, info = env.reset()
done = False
while not done and len(frames) <= LENGTH:
terminated, truncated = False, False
while not (terminated or truncated) and len(frames) <= LENGTH:
frame = env.render(mode="rgb_array")
repeat = (
@@ -66,7 +66,7 @@ for env_spec in tqdm(gymnasium.envs.registry.values()):
for i in range(repeat):
frames.append(Image.fromarray(frame))
action = env.action_space.sample()
state_next, reward, done, info = env.step(action)
state_next, reward, terminated, truncated, info = env.step(action)
if len(frames) > LENGTH:
break

View File

@@ -16,7 +16,7 @@ class StepAPICompatibility(gym.Wrapper):
Args:
env (gym.Env): the env to wrap. Can be in old or new API
apply_step_compatibility (bool): Apply to convert environment to use new step API that returns two bools. (False by default)
output_truncation_bool (bool): Apply to convert the environment to use the new step API that returns two bools. (True by default)
Examples:
>>> env = gym.make("CartPole-v1")
@@ -24,7 +24,7 @@ class StepAPICompatibility(gym.Wrapper):
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
>>> env = gym.make("CartPole-v1", apply_api_compatibility=True) # set to old API
<StepAPICompatibility<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>>
>>> env = StepAPICompatibility(CustomEnv(), apply_step_compatibility=False) # manually using wrapper on unregistered envs
>>> env = StepAPICompatibility(CustomEnv(), output_truncation_bool=False) # manually using wrapper on unregistered envs
"""
@@ -43,7 +43,7 @@ class StepAPICompatibility(gym.Wrapper):
)
def step(self, action):
"""Steps through the environment, returning 5 or 4 items depending on `apply_step_compatibility`.
"""Steps through the environment, returning 5 or 4 items depending on `output_truncation_bool`.
Args:
action: action to step through the environment with