Randomize LunarLander wind generation at reset to gain statistical independence between episodes (#959)

TobiasKallehauge authored on 2024-03-09 10:42:08 +01:00, committed by GitHub
parent d684778e9d
commit fd4ae52045
8 changed files with 70 additions and 36 deletions

View File

@@ -23,7 +23,7 @@ An API standard for reinforcement learning with a diverse collection of referenc
import gymnasium as gym
# Initialise the environment
env = gym.make("LunarLander-v2", render_mode="human")
env = gym.make("LunarLander-v3", render_mode="human")
# Reset the environment to generate the first observation
observation, info = env.reset(seed=42)
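Worth noting in the context of this commit: when wind is enabled, the wind and turbulence offsets are now re-drawn from the environment's seeded RNG on every reset, so the seed passed to `reset` also determines the wind pattern. A minimal sketch:

```python
import gymnasium as gym

env = gym.make("LunarLander-v3", enable_wind=True)
obs, info = env.reset(seed=42)  # the seed now also fixes the wind/turbulence offsets
obs, info = env.reset()         # unseeded resets draw fresh, independent offsets
```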

View File

@@ -54,7 +54,7 @@ For gymnasium, the "agent-environment-loop" is implemented below for a single ep
```python
import gymnasium as gym
env = gym.make("LunarLander-v2", render_mode="human")
env = gym.make("LunarLander-v3", render_mode="human")
observation, info = env.reset()
episode_over = False
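The hunk cuts off here; the loop that the docs build on top of `episode_over` typically continues along these lines (a sketch, the full example in the docs may differ slightly):

```python
while not episode_over:
    action = env.action_space.sample()  # random policy as a placeholder
    observation, reward, terminated, truncated, info = env.step(action)
    episode_over = terminated or truncated

env.close()
```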

View File

@@ -15,7 +15,7 @@ Gymnasium is a fork of `OpenAI Gym v0.26 <https://github.com/openai/gym/releases
```python
import gym
env = gym.make("LunarLander-v2", options={})
env = gym.make("LunarLander-v3", options={})
env.seed(123)
observation = env.reset()
@@ -33,7 +33,7 @@ env.close()
```python
import gym
env = gym.make("LunarLander-v2", render_mode="human")
env = gym.make("LunarLander-v3", render_mode="human")
observation, info = env.reset(seed=123, options={})
done = False
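Besides seeding via `reset`, the other key migration point is the step signature; a sketch of the loop body under the new API, with the old one for contrast in a comment:

```python
while not done:
    action = env.action_space.sample()
    # old gym:  observation, reward, done, info = env.step(action)
    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated

env.close()
```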

View File

@@ -267,7 +267,7 @@ class A2C(nn.Module):
# The simplest way to create vector environments is by calling `gym.vector.make`, which creates multiple instances of the same environment:
#
envs = gym.vector.make("LunarLander-v2", num_envs=3, max_episode_steps=600)
envs = gym.vector.make("LunarLander-v3", num_envs=3, max_episode_steps=600)
# %%
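Vector environments batch everything along a leading `num_envs` axis; a quick sketch of interacting with the `envs` object created above:

```python
observations, infos = envs.reset(seed=42)  # observations has shape (3, 8) for LunarLander
actions = envs.action_space.sample()       # one action per sub-environment
observations, rewards, terminations, truncations, infos = envs.step(actions)
```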
@@ -277,13 +277,13 @@ envs = gym.vector.make("LunarLander-v2", num_envs=3, max_episode_steps=600)
# If we want to randomize the environment for training to get more robust agents (that can deal with different parameterizations of an environment
# and therefore might generalize better), we can set the desired parameters manually or use a pseudo-random number generator to generate them.
#
# Manually setting up 3 parallel 'LunarLander-v2' envs with different parameters:
# Manually setting up 3 parallel 'LunarLander-v3' envs with different parameters:
envs = gym.vector.AsyncVectorEnv(
[
lambda: gym.make(
"LunarLander-v2",
"LunarLander-v3",
gravity=-10.0,
enable_wind=True,
wind_power=15.0,
@@ -291,7 +291,7 @@ envs = gym.vector.AsyncVectorEnv(
max_episode_steps=600,
),
lambda: gym.make(
"LunarLander-v2",
"LunarLander-v3",
gravity=-9.8,
enable_wind=True,
wind_power=10.0,
@@ -299,7 +299,7 @@ envs = gym.vector.AsyncVectorEnv(
max_episode_steps=600,
),
lambda: gym.make(
"LunarLander-v2", gravity=-7.0, enable_wind=False, max_episode_steps=600
"LunarLander-v3", gravity=-7.0, enable_wind=False, max_episode_steps=600
),
]
)
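Note that `AsyncVectorEnv` takes a list of zero-argument callables rather than environment instances: each callable is an environment factory that is invoked inside its own worker process, which is why the differing parameters are baked into each `lambda` above.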
@@ -309,14 +309,14 @@ envs = gym.vector.AsyncVectorEnv(
#
# ------------------------------
#
# Randomly generating the parameters for 3 parallel 'LunarLander-v2' envs, using `np.clip` to stay in the recommended parameter space:
# Randomly generating the parameters for 3 parallel 'LunarLander-v3' envs, using `np.clip` to stay in the recommended parameter space:
#
envs = gym.vector.AsyncVectorEnv(
[
lambda: gym.make(
"LunarLander-v2",
"LunarLander-v3",
gravity=np.clip(
np.random.normal(loc=-10.0, scale=1.0), a_min=-11.99, a_max=-0.01
),
@@ -374,7 +374,7 @@ if randomize_domain:
envs = gym.vector.AsyncVectorEnv(
[
lambda: gym.make(
"LunarLander-v2",
"LunarLander-v3",
gravity=np.clip(
np.random.normal(loc=-10.0, scale=1.0), a_min=-11.99, a_max=-0.01
),
@@ -392,7 +392,7 @@ if randomize_domain:
)
else:
envs = gym.vector.make("LunarLander-v2", num_envs=n_envs, max_episode_steps=600)
envs = gym.vector.make("LunarLander-v3", num_envs=n_envs, max_episode_steps=600)
obs_shape = envs.single_observation_space.shape[0]
@@ -499,7 +499,7 @@ for sample_phase in tqdm(range(n_updates)):
rolling_length = 20
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 5))
fig.suptitle(
f"Training plots for {agent.__class__.__name__} in the LunarLander-v2 environment \n \
f"Training plots for {agent.__class__.__name__} in the LunarLander-v3 environment \n \
(n_envs={n_envs}, n_steps_per_update={n_steps_per_update}, randomize_domain={randomize_domain})"
)
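The `rolling_length = 20` above is used to smooth the noisy per-episode curves before plotting; a common way to compute such a moving average (a sketch, not necessarily the tutorial's exact code):

```python
import numpy as np

def rolling_average(values, window):
    """Simple moving average used to smooth noisy training curves."""
    return np.convolve(np.asarray(values).flatten(), np.ones(window), mode="valid") / window
```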
@@ -606,7 +606,7 @@ plt.show()
# because the gradient estimates are already good enough after a relatively low number of environments
# (especially if the environment is not very complex). In this case, increasing the number of environments
# does not increase the learning speed, and actually increases the runtime, possibly due to the additional time
# needed to calculate the gradients. For LunarLander-v2, the best performing configuration used an AsyncVectorEnv
# needed to calculate the gradients. For LunarLander-v3, the best performing configuration used an AsyncVectorEnv
# with 10 parallel environments, but environments with a higher complexity may require more
# parallel environments to achieve optimal performance.
#
@@ -662,7 +662,7 @@ for episode in range(n_showcase_episodes):
# create a new sample environment to get new random parameters
if randomize_domain:
env = gym.make(
"LunarLander-v2",
"LunarLander-v3",
render_mode="human",
gravity=np.clip(
np.random.normal(loc=-10.0, scale=2.0), a_min=-11.99, a_max=-0.01
@@ -677,7 +677,7 @@ for episode in range(n_showcase_episodes):
max_episode_steps=500,
)
else:
env = gym.make("LunarLander-v2", render_mode="human", max_episode_steps=500)
env = gym.make("LunarLander-v3", render_mode="human", max_episode_steps=500)
# get an initial state
state, info = env.reset()
@@ -705,7 +705,7 @@ env.close()
# from gymnasium.utils.play import play
#
# play(gym.make('LunarLander-v2', render_mode='rgb_array'),
# play(gym.make('LunarLander-v3', render_mode='rgb_array'),
# keys_to_action={'w': 2, 'a': 1, 'd': 3}, noop=0)
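For reference, LunarLander's discrete action space is 0 = do nothing, 1 = fire left orientation engine, 2 = fire main engine, 3 = fire right orientation engine, which is why `w`/`a`/`d` are mapped to 2/1/3 above.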

View File

@@ -81,14 +81,14 @@ register(
# ----------------------------------------
register(
id="LunarLander-v2",
id="LunarLander-v3",
entry_point="gymnasium.envs.box2d.lunar_lander:LunarLander",
max_episode_steps=1000,
reward_threshold=200,
)
register(
id="LunarLanderContinuous-v2",
id="LunarLanderContinuous-v3",
entry_point="gymnasium.envs.box2d.lunar_lander:LunarLander",
kwargs={"continuous": True},
max_episode_steps=1000,
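Once registered, the new ids resolve through `gym.make` like any other environment:

```python
import gymnasium as gym

env = gym.make("LunarLanderContinuous-v3")  # resolved via the registry entry above
```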

View File

@@ -97,7 +97,7 @@ class LunarLander(gym.Env, EzPickle):
python gymnasium/envs/box2d/lunar_lander.py
```
<!-- To play yourself, run: -->
<!-- python examples/agents/keyboard_agent.py LunarLander-v2 -->
<!-- python examples/agents/keyboard_agent.py LunarLander-v3 -->
## Action Space
There are four discrete actions available:
@@ -150,10 +150,10 @@ class LunarLander(gym.Env, EzPickle):
```python
>>> import gymnasium as gym
>>> env = gym.make("LunarLander-v2", continuous=False, gravity=-10.0,
>>> env = gym.make("LunarLander-v3", continuous=False, gravity=-10.0,
... enable_wind=False, wind_power=15.0, turbulence_power=1.5)
>>> env
<TimeLimit<OrderEnforcing<PassiveEnvChecker<LunarLander<LunarLander-v2>>>>>
<TimeLimit<OrderEnforcing<PassiveEnvChecker<LunarLander<LunarLander-v3>>>>>
```
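As the repr shows, `gym.make` wraps the raw environment in `PassiveEnvChecker`, `OrderEnforcing` and `TimeLimit` by default; `env.unwrapped` reaches the underlying `LunarLander` instance, which is how the new test in this commit inspects `wind_idx` and `torque_idx` directly.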
@@ -179,6 +179,7 @@ class LunarLander(gym.Env, EzPickle):
The recommended value for `turbulence_power` is between 0.0 and 2.0.
## Version History
- v3: Reset wind and turbulence offset (`C`) whenever the environment is reset to ensure statistical independence between consecutive episodes (related [GitHub issue](https://github.com/Farama-Foundation/Gymnasium/issues/954)).
- v2: Count energy spent; in v0.24, added turbulence with `wind_power` and `turbulence_power` parameters
- v1: Legs contact with ground added in state vector; contact with ground gives +10 reward points,
and -10 if contact is lost again; reward renormalized to 200; harder initial random push.
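For context on why resetting the offsets matters: each step, the wind and turbulence applied to the lander are deterministic functions of the running `wind_idx`/`torque_idx` counters, roughly as sketched below (paraphrased from the environment source; exact constants may differ). Before this change the offsets were drawn once per environment instance in `__init__` from the unseeded global `np.random`, so consecutive episodes continued the same wind trajectory; drawing them in `reset` from `self.np_random` makes episodes independent and reproducible under `reset(seed=...)`.

```python
import math

# placeholder values; in the environment these are drawn from the seeded RNG at reset()
wind_idx, torque_idx = 4242, -1337
wind_power, turbulence_power = 15.0, 1.5

# both indices are incremented every step, so re-drawing them at reset()
# (the v3 change) decorrelates the wind pattern between consecutive episodes
wind_mag = math.tanh(
    math.sin(0.02 * wind_idx) + math.sin(math.pi * 0.01 * wind_idx)
) * wind_power
torque_mag = math.tanh(
    math.sin(0.02 * torque_idx) + math.sin(math.pi * 0.01 * torque_idx)
) * turbulence_power
```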
@@ -254,8 +255,6 @@ class LunarLander(gym.Env, EzPickle):
self.turbulence_power = turbulence_power
self.enable_wind = enable_wind
self.wind_idx = np.random.randint(-9999, 9999)
self.torque_idx = np.random.randint(-9999, 9999)
self.screen: pygame.Surface = None
self.clock = None
@@ -403,6 +402,10 @@ class LunarLander(gym.Env, EzPickle):
True,
)
if self.enable_wind: # Initialize wind pattern based on index
self.wind_idx = self.np_random.integers(-9999, 9999)
self.torque_idx = self.np_random.integers(-9999, 9999)
# Create Lander Legs
self.legs = []
for i in [-1, +1]:
@@ -872,10 +875,10 @@ class LunarLanderContinuous:
"Error initializing LunarLanderContinuous Environment.\n"
"Currently, we do not support initializing this mode of environment by calling the class directly.\n"
"To use this environment, instead create it by specifying the continuous keyword in gym.make, i.e.\n"
'gym.make("LunarLander-v2", continuous=True)'
'gym.make("LunarLander-v3", continuous=True)'
)
if __name__ == "__main__":
env = gym.make("LunarLander-v2", render_mode="rgb_array")
env = gym.make("LunarLander-v3", render_mode="rgb_array")
demo_heuristic_lander(env, render=True)

View File

@@ -35,7 +35,7 @@ class RenderCollection(
Example:
Return the list of frames collected since ``render`` was last called.
>>> import gymnasium as gym
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
>>> env = RenderCollection(env)
>>> _ = env.reset(seed=123)
>>> for _ in range(5):
@@ -51,7 +51,7 @@ class RenderCollection(
Return the list of frames for all steps the episode has been running.
>>> import gymnasium as gym
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
>>> env = RenderCollection(env, pop_frames=False)
>>> _ = env.reset(seed=123)
>>> for _ in range(5):
@@ -67,7 +67,7 @@ class RenderCollection(
Collect all frames for all episodes, without clearing them when render is called
>>> import gymnasium as gym
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
>>> env = RenderCollection(env, pop_frames=False, reset_clean=False)
>>> _ = env.reset(seed=123)
>>> for _ in range(5):
@@ -177,7 +177,7 @@ class RecordVideo(
Examples - Run the environment for 50 episodes, and save the video every 10 episodes starting from the 0th:
>>> import os
>>> import gymnasium as gym
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
>>> trigger = lambda t: t % 10 == 0
>>> env = RecordVideo(env, video_folder="./save_videos1", episode_trigger=trigger, disable_logger=True)
>>> for i in range(50):
@@ -193,7 +193,7 @@ class RecordVideo(
Examples - Run the environment for 5 episodes, start a recording every 200th step, making sure each video is 100 frames long:
>>> import os
>>> import gymnasium as gym
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
>>> trigger = lambda t: t % 200 == 0
>>> env = RecordVideo(env, video_folder="./save_videos2", step_trigger=trigger, video_length=100, disable_logger=True)
>>> for i in range(5):
@@ -210,7 +210,7 @@ class RecordVideo(
Examples - Run 3 episodes, record everything, but in chunks of 1000 frames:
>>> import os
>>> import gymnasium as gym
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
>>> env = RecordVideo(env, video_folder="./save_videos3", video_length=1000, disable_logger=True)
>>> for i in range(3):
... termination, truncation = False, False
@@ -432,7 +432,7 @@ class HumanRendering(
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import HumanRendering
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
>>> wrapped = HumanRendering(env)
>>> obs, _ = wrapped.reset() # This will start rendering to the screen
@@ -446,7 +446,7 @@ class HumanRendering(
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
will always return an empty list:
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
>>> env = gym.make("LunarLander-v3", render_mode="rgb_array_list")
>>> wrapped = HumanRendering(env)
>>> obs, _ = wrapped.reset()
>>> env.render() # env.render() will always return an empty list!

View File

@@ -12,11 +12,42 @@ from gymnasium.envs.toy_text.frozen_lake import generate_random_map
def test_lunar_lander_heuristics():
"""Tests the LunarLander environment by checking if the heuristic lander works."""
lunar_lander = gym.make("LunarLander-v2", disable_env_checker=True)
lunar_lander = gym.make("LunarLander-v3", disable_env_checker=True)
total_reward = demo_heuristic_lander(lunar_lander, seed=1)
assert total_reward > 100
@pytest.mark.parametrize("seed", [0, 10, 20, 30, 40])
def test_lunar_lander_random_wind_seed(seed: int):
"""Test that the wind_idx and torque are correctly drawn when setting a seed"""
lunar_lander = gym.make(
"LunarLander-v3", disable_env_checker=True, enable_wind=True
).unwrapped
lunar_lander.reset(seed=seed)
# Test that same seed gives same wind
w1, t1 = lunar_lander.wind_idx, lunar_lander.torque_idx
lunar_lander.reset(seed=seed)
w2, t2 = lunar_lander.wind_idx, lunar_lander.torque_idx
assert (
w1 == w2 and t1 == t2
), "Setting same seed caused different initial wind or torque index"
# Test that different seed gives different wind
# There is a small chance that different seeds draw the same numbers, so test
# up to 10 times (with different seeds) to make this chance vanishingly small.
for i in range(1, 11):
lunar_lander.reset(seed=seed + i)
w3, t3 = lunar_lander.wind_idx, lunar_lander.torque_idx
if w2 != w3 and t2 != t3:  # Found different initial values
break
else: # no break
raise AssertionError(
"Setting different seed caused same initial wind or torque index"
)
def test_carracing_domain_randomize():
"""Tests the CarRacing Environment domain randomization.