Add tutorial: qlearning with/without action masking for Taxi v3 env (#1345)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
docs/_static/img/tutorials/taxi_v3_action_masking_comparison.png (new binary file, 108 KiB)
docs/tutorials/training_agents/action_masking_taxi.py (new file, 297 lines)
r"""
Action Masking in the Taxi Environment
======================================

This tutorial demonstrates how to use action masking in the Taxi environment to improve reinforcement learning performance by preventing invalid actions.

The Taxi environment is a classic grid-world problem where a taxi needs to pick up a passenger and drop them off at their destination. In this environment, not all actions are valid in every state: for example, the taxi cannot drive through walls or pick up a passenger when it is not at the passenger's location.

Action masking is a technique that helps reinforcement learning agents avoid selecting invalid actions by providing a binary mask that indicates which actions are valid in the current state. This can significantly improve learning efficiency and performance.

Understanding the Taxi Environment
----------------------------------

The Taxi environment has 6 possible actions:

- 0: Move south (down)
- 1: Move north (up)
- 2: Move east (right)
- 3: Move west (left)
- 4: Pickup passenger
- 5: Drop off passenger

The environment provides an ``action_mask`` in the info dictionary returned by ``reset()`` and ``step()``.
This mask is a binary array where 1 indicates a valid action and 0 indicates an invalid action.
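
For example, a minimal sketch of reading the mask after a reset (the exact values depend on the sampled start state):

.. code-block:: python

    import gymnasium as gym

    env = gym.make("Taxi-v3")
    obs, info = env.reset(seed=0)
    print(info["action_mask"])  # length-6 array of 0/1 flags, one entry per action
    env.close()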

How Action Masking Works
------------------------

Action masking works by constraining the agent's action selection to only valid actions:

1. During exploration: When selecting random actions, we only choose from the set of valid actions
2. During exploitation: When selecting the best action based on Q-values, we only consider Q-values for valid actions
3. During Q-learning updates: We compute the maximum future Q-value only over valid actions in the next state, as written out in the update rule below
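
Concretely, step 3 uses the standard Q-learning update with the maximum restricted to the set of valid actions :math:`\mathcal{A}(s')` in the next state (a restatement of the code below, not an additional mechanism):

.. math::

    Q(s, a) \leftarrow Q(s, a) + \alpha \left( r + \gamma \max_{a' \in \mathcal{A}(s')} Q(s', a') - Q(s, a) \right)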

Let's implement this step by step:
"""

import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import gymnasium as gym


# Base random seed for reproducibility
BASE_RANDOM_SEED = 58922320


def train_q_learning(
    env,
    use_action_mask: bool = True,
    episodes: int = 5000,
    seed: int = BASE_RANDOM_SEED,
    learning_rate: float = 0.1,
    discount_factor: float = 0.95,
    epsilon: float = 0.1,
) -> dict:
    """Train a Q-learning agent with or without action masking."""
    # Set random seeds for reproducibility
    np.random.seed(seed)
    random.seed(seed)

    # Initialize Q-table
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    q_table = np.zeros((n_states, n_actions))

    # Track episode rewards for analysis
    episode_rewards = []

    for episode in range(episodes):
        # Reset environment
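        # Seeding each reset with ``seed + episode`` varies the start state across
        # episodes while keeping the whole run reproducible.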
        state, info = env.reset(seed=seed + episode)
        total_reward = 0
        done = False
        truncated = False

        while not (done or truncated):
            # Get action mask if using it
            action_mask = info["action_mask"] if use_action_mask else None

            # Epsilon-greedy action selection with masking
            if np.random.random() < epsilon:
                # Random action selection
                if use_action_mask:
                    # Only select from valid actions
                    valid_actions = np.nonzero(action_mask == 1)[0]
                    action = np.random.choice(valid_actions)
                else:
                    # Select from all actions
                    action = np.random.randint(0, n_actions)
            else:
                # Greedy action selection
                if use_action_mask:
                    # Only consider valid actions for exploitation
                    valid_actions = np.nonzero(action_mask == 1)[0]
                    if len(valid_actions) > 0:
                        action = valid_actions[np.argmax(q_table[state, valid_actions])]
                    else:
                        action = np.random.randint(0, n_actions)
                else:
                    # Consider all actions
                    action = np.argmax(q_table[state])

            # Take action and observe result
            next_state, reward, done, truncated, info = env.step(action)
            total_reward += reward

            # Q-learning update
            if not (done or truncated):
                if use_action_mask:
                    # Only consider valid next actions for bootstrapping
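                    # ``info`` was returned by the ``env.step()`` call above, so its
                    # ``action_mask`` describes which actions are valid in ``next_state``.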
                    next_mask = info["action_mask"]
                    valid_next_actions = np.nonzero(next_mask == 1)[0]
                    if len(valid_next_actions) > 0:
                        next_max = np.max(q_table[next_state, valid_next_actions])
                    else:
                        next_max = 0
                else:
                    # Consider all next actions
                    next_max = np.max(q_table[next_state])

                # Update Q-value
                q_table[state, action] = q_table[state, action] + learning_rate * (
                    reward + discount_factor * next_max - q_table[state, action]
                )

            state = next_state

        episode_rewards.append(total_reward)

    return {
        "episode_rewards": episode_rewards,
        "mean_reward": np.mean(episode_rewards),
        "std_reward": np.std(episode_rewards),
    }


# %%
# Running the Experiment
# ======================
#
# Now we'll run experiments to compare the performance of Q-learning agents with and without action masking.
# We'll use multiple random seeds to ensure a robust statistical comparison.
#
# The experiment setup:
#
# - 12 independent runs with different random seeds
# - 5000 episodes per run
# - Standard Q-learning hyperparameters ``(α=0.1, γ=0.95, ε=0.1)``
#


# Experiment parameters
n_runs = 12
episodes = 5000
learning_rate = 0.1
discount_factor = 0.95
epsilon = 0.1

# Generate different seeds for each run
seeds = [BASE_RANDOM_SEED + i for i in range(n_runs)]

# Store results for comparison
masked_results_list = []
unmasked_results_list = []

# Run experiments with different seeds
for i, seed in enumerate(seeds):
    print(f"Run {i + 1}/{n_runs} with seed {seed}")

    # Train agent WITH action masking
    env_masked = gym.make("Taxi-v3")
    masked_results = train_q_learning(
        env_masked,
        use_action_mask=True,
        seed=seed,
        learning_rate=learning_rate,
        discount_factor=discount_factor,
        epsilon=epsilon,
        episodes=episodes,
    )
    env_masked.close()
    masked_results_list.append(masked_results)

    # Train agent WITHOUT action masking
    env_unmasked = gym.make("Taxi-v3")
    unmasked_results = train_q_learning(
        env_unmasked,
        use_action_mask=False,
        seed=seed,
        learning_rate=learning_rate,
        discount_factor=discount_factor,
        epsilon=epsilon,
        episodes=episodes,
    )
    env_unmasked.close()
    unmasked_results_list.append(unmasked_results)


# %%
# Visualizing Results
# ===================
#
# After running all experiments, we calculate statistics and create visualizations to compare the performance of both approaches.
#

# Calculate statistics across runs
masked_mean_rewards = [r["mean_reward"] for r in masked_results_list]
unmasked_mean_rewards = [r["mean_reward"] for r in unmasked_results_list]

masked_overall_mean = np.mean(masked_mean_rewards)
masked_overall_std = np.std(masked_mean_rewards)
unmasked_overall_mean = np.mean(unmasked_mean_rewards)
unmasked_overall_std = np.std(unmasked_mean_rewards)

# Create visualization
plt.figure(figsize=(12, 8), dpi=100)

# Plot individual runs with low alpha
for i, (masked_results, unmasked_results) in enumerate(
    zip(masked_results_list, unmasked_results_list)
):
    plt.plot(
        masked_results["episode_rewards"],
        label="With Action Masking" if i == 0 else None,
        color="blue",
        alpha=0.1,
    )
    plt.plot(
        unmasked_results["episode_rewards"],
        label="Without Action Masking" if i == 0 else None,
        color="red",
        alpha=0.1,
    )

# Calculate and plot mean curves across all runs
masked_mean_curve = np.mean([r["episode_rewards"] for r in masked_results_list], axis=0)
unmasked_mean_curve = np.mean(
    [r["episode_rewards"] for r in unmasked_results_list], axis=0
)

plt.plot(
    masked_mean_curve, label="With Action Masking (Mean)", color="blue", linewidth=2
)
plt.plot(
    unmasked_mean_curve,
    label="Without Action Masking (Mean)",
    color="red",
    linewidth=2,
)

plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("Training Performance: Q-Learning with vs without Action Masking")
plt.legend()
plt.grid(True, alpha=0.3)

# Save the figure
savefig_folder = Path("_static/img/tutorials/")
savefig_folder.mkdir(parents=True, exist_ok=True)
plt.savefig(
    savefig_folder / "taxi_v3_action_masking_comparison.png",
    bbox_inches="tight",
    dpi=150,
)
plt.show()


# %%
# .. image:: /_static/img/tutorials/taxi_v3_action_masking_comparison.png
#
# Results Analysis
# ================
#
# The comparison demonstrates several important benefits of using action masking:
#
# Key Benefits of Action Masking:
#
# 1. Faster Convergence: Agents with action masking typically learn faster because they don't waste time exploring
#    invalid actions.
#
# 2. Better Performance: By focusing only on valid actions, the agent can achieve higher rewards more consistently.
#
# 3. More Stable Learning: Action masking reduces the variance in learning by eliminating the randomness associated with invalid action selection.
#
# 4. Practical Applicability: In real-world scenarios, action masking prevents agents from taking actions that could be dangerous or impossible.
#
# Reminder of Key Implementation Details
# --------------------------------------
#
# - Action Selection: We filter available actions using ``np.nonzero(action_mask == 1)[0]`` to get only valid actions
# - Q-Value Updates: When computing the maximum future Q-value, we only consider valid actions in the next state
# - Exploration: Random action selection is constrained to the set of valid actions (a condensed sketch of this selection logic follows below)
#
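# As a compact reference, the masked epsilon-greedy choice above can be written as a
# small helper (an illustrative sketch; the name ``masked_epsilon_greedy`` is
# hypothetical and the function is not used elsewhere in this tutorial):
#
# .. code-block:: python
#
#     def masked_epsilon_greedy(q_row, action_mask, epsilon):
#         """Pick an action for one state, restricted to entries where action_mask == 1."""
#         # illustrative sketch only; q_row is one row of the Q-table, e.g. q_table[state]
#         valid_actions = np.nonzero(action_mask == 1)[0]
#         if np.random.random() < epsilon:
#             return np.random.choice(valid_actions)  # masked exploration
#         return valid_actions[np.argmax(q_row[valid_actions])]  # masked exploitation
#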
# This approach ensures that the agent never selects invalid actions while still maintaining the exploration-exploitation balance necessary for effective learning.
#