mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 05:44:31 +00:00
Incorporated a few spelling and logical fixes and added training logs (#109)
This commit is contained in:
BIN
docs/_static/img/tutorials/blackjack_training_plots.png
vendored
Normal file
BIN
docs/_static/img/tutorials/blackjack_training_plots.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 70 KiB |
Binary file not shown.
Before Width: | Height: | Size: 122 KiB After Width: | Height: | Size: 102 KiB |
Binary file not shown.
Before Width: | Height: | Size: 126 KiB After Width: | Height: | Size: 107 KiB |
@@ -4,7 +4,6 @@ Solving Blackjack with Q-Learning
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# %%
|
||||
# .. image:: /_static/img/tutorials/blackjack_AE_loop.jpg
|
||||
# :width: 650
|
||||
@@ -17,10 +16,15 @@ Solving Blackjack with Q-Learning
|
||||
# infamous for being beatable under certain conditions. This version of
|
||||
# the game uses an infinite deck (we draw the cards with replacement), so
|
||||
# counting cards won’t be a viable strategy in our simulated game.
|
||||
# Full documentation can be found at https://gymnasium.farama.org/environments/toy_text/blackjack
|
||||
#
|
||||
# **Objective**: To win, your card sum should be greater than than the
|
||||
# **Objective**: To win, your card sum should be greater than the
|
||||
# dealers without exceeding 21.
|
||||
#
|
||||
# **Actions**: Agents can pick between two actions:
|
||||
# - stand (0): the player takes no more cards
|
||||
# - hit (1): the player will be given another card, however the player could get over 21 and bust
|
||||
#
|
||||
# **Approach**: To solve this environment by yourself, you can pick your
|
||||
# favorite discrete RL algorithm. The presented solution uses *Q-learning*
|
||||
# (a model-free RL algorithm).
|
||||
@@ -35,13 +39,17 @@ Solving Blackjack with Q-Learning
|
||||
# Author: Till Zemann
|
||||
# License: MIT License
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import gym
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.patches import Patch
|
||||
from tqdm import tqdm
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
# Let's start by creating the blackjack environment.
|
||||
# Note: We are going to follow the rules from Sutton & Barto.
|
||||
@@ -49,15 +57,16 @@ from matplotlib.patches import Patch
|
||||
|
||||
env = gym.make("Blackjack-v1", sab=True)
|
||||
|
||||
|
||||
# %%
|
||||
# .. code:: py
|
||||
#
|
||||
# # Other possible environment configurations:
|
||||
# # Other possible environment configurations are:
|
||||
#
|
||||
# env = gym.make('Blackjack-v1', natural=True, sab=False)``
|
||||
# env = gym.make('Blackjack-v1', natural=True, sab=False)
|
||||
# # Whether to give an additional reward for starting with a natural blackjack, i.e. starting with an ace and ten (sum is 21).
|
||||
#
|
||||
# env = gym.make('Blackjack-v1', natural=False, sab=False)``
|
||||
# env = gym.make('Blackjack-v1', natural=False, sab=False)
|
||||
# # Whether to follow the exact rules outlined in the book by Sutton and Barto. If `sab` is `True`, the keyword argument `natural` will be ignored.
|
||||
#
|
||||
|
||||
|
||||
@@ -68,21 +77,18 @@ env = gym.make("Blackjack-v1", sab=True)
|
||||
# First of all, we call ``env.reset()`` to start an episode. This function
|
||||
# resets the environment to a starting position and returns an initial
|
||||
# ``observation``. We usually also set ``done = False``. This variable
|
||||
# will be useful later to check if a game is terminated. In this tutorial
|
||||
# we will use the terms observation and state synonymously but in more
|
||||
# complex problems a state might differ from the observation it is based
|
||||
# on.
|
||||
# will be useful later to check if a game is terminated (i.e., the player wins or loses).
|
||||
#
|
||||
|
||||
# reset the environment to get the first observation
|
||||
done = False
|
||||
observation, info = env.reset()
|
||||
|
||||
print(observation)
|
||||
# observation = (16, 9, False)
|
||||
|
||||
|
||||
# %%
|
||||
# Note that our observation is a 3-tuple consisting of 3 discrete values:
|
||||
# Note that our observation is a 3-tuple consisting of 3 values:
|
||||
#
|
||||
# - The players current sum
|
||||
# - Value of the dealers face-up card
|
||||
@@ -106,13 +112,13 @@ print(observation)
|
||||
# - ``reward``: This is the reward that the agent will receive after
|
||||
# taking the action.
|
||||
# - ``terminated``: This is a boolean variable that indicates whether or
|
||||
# not the episode is over.
|
||||
# not the environment has terminated.
|
||||
# - ``truncated``: This is a boolean variable that also indicates whether
|
||||
# the episode ended by early truncation.
|
||||
# the episode ended by early truncation, i.e., a time limit is reached.
|
||||
# - ``info``: This is a dictionary that might contain additional
|
||||
# information about the environment.
|
||||
#
|
||||
# The ``next_state``, ``reward``, and ``done`` variables are
|
||||
# The ``next_state``, ``reward``, ``terminated`` and ``truncated`` variables are
|
||||
# self-explanatory, but the ``info`` variable requires some additional
|
||||
# explanation. This variable contains a dictionary that might have some
|
||||
# extra information about the environment, but in the Blackjack-v1
|
||||
@@ -120,34 +126,29 @@ print(observation)
|
||||
# info dictionary has a ``ale.lives`` key that tells us how many lives the
|
||||
# agent has left. If the agent has 0 lives, then the episode is over.
|
||||
#
|
||||
# Blackjack-v1 doesn’t have a ``env.render()`` function to render the
|
||||
# environment, but in other environments you can use this function to
|
||||
# watch the agent play. Important to note is that using ``env.render()``
|
||||
# is optional - the environment is going to work even if you don’t render
|
||||
# it, but it can be helpful to see an episode rendered out to get an idea
|
||||
# of how the current policy behaves. Note that it is not a good idea to
|
||||
# call this function in your training loop because rendering slows down
|
||||
# training by a lot. Rather try to build an extra loop to evaluate and
|
||||
# showcase the agent after training.
|
||||
# Note that it is not a good idea to call ``env.render()`` in your training
|
||||
# loop because rendering slows down training by a lot. Rather try to build
|
||||
# an extra loop to evaluate and showcase the agent after training.
|
||||
#
|
||||
|
||||
# sample a random action from all valid actions
|
||||
action = env.action_space.sample()
|
||||
# action=1
|
||||
|
||||
# execute the action in our environment and receive infos from the environment
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
print("observation:", observation)
|
||||
print("reward:", reward)
|
||||
print("terminated:", terminated)
|
||||
print("truncated:", truncated)
|
||||
print("info:", info)
|
||||
# observation=(24, 10, False)
|
||||
# reward=-1.0
|
||||
# terminated=True
|
||||
# truncated=False
|
||||
# info={}
|
||||
|
||||
|
||||
# %%
|
||||
# Once ``terminated = True`` or ``truncated=True``, we should stop the
|
||||
# current episode and begin a new one with ``env.reset()``. If you
|
||||
# continue executing act`ons without resetting the environment, it still
|
||||
# continue executing actions without resetting the environment, it still
|
||||
# responds but the output won’t be useful for training (it might even be
|
||||
# harmful if the agent learns on invalid data).
|
||||
#
|
||||
@@ -159,7 +160,7 @@ print("info:", info)
|
||||
#
|
||||
# Let’s build a ``Q-learning agent`` to solve *Blackjack-v1*! We’ll need
|
||||
# some functions for picking an action and updating the agents action
|
||||
# values. To ensure that the agents expores the environment, one possible
|
||||
# values. To ensure that the agents explores the environment, one possible
|
||||
# solution is the ``epsilon-greedy`` strategy, where we pick a random
|
||||
# action with the percentage ``epsilon`` and the greedy action (currently
|
||||
# valued as the best) ``1 - epsilon``.
|
||||
@@ -167,43 +168,69 @@ print("info:", info)
|
||||
|
||||
|
||||
class BlackjackAgent:
|
||||
def __init__(self, lr=1e-3, epsilon=0.1, epsilon_decay=1e-4):
|
||||
"""
|
||||
Initialize an Reinforcement Learning agent with an empty dictionary
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate: float,
|
||||
initial_epsilon: float,
|
||||
epsilon_decay: float,
|
||||
final_epsilon: float,
|
||||
discount_factor: float = 0.95,
|
||||
):
|
||||
"""Initialize a Reinforcement Learning agent with an empty dictionary
|
||||
of state-action values (q_values), a learning rate and an epsilon.
|
||||
"""
|
||||
self.q_values = defaultdict(
|
||||
lambda: np.zeros(env.action_space.n)
|
||||
) # maps a state to action values
|
||||
self.lr = lr
|
||||
self.epsilon = epsilon
|
||||
self.epsilon_decay = epsilon_decay
|
||||
|
||||
def get_action(self, state):
|
||||
Args:
|
||||
learning_rate: The learning rate
|
||||
initial_epsilon: The initial epsilon value
|
||||
epsilon_decay: The decay for epsilon
|
||||
final_epsilon: The final epsilon value
|
||||
discount_factor: The discount factor for computing the Q-value
|
||||
"""
|
||||
self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))
|
||||
|
||||
self.lr = learning_rate
|
||||
self.discount_factor = discount_factor
|
||||
|
||||
self.epsilon = initial_epsilon
|
||||
self.epsilon_decay = epsilon_decay
|
||||
self.final_epsilon = final_epsilon
|
||||
|
||||
self.training_error = []
|
||||
|
||||
def get_action(self, obs: tuple[int, int, bool]) -> int:
|
||||
"""
|
||||
Returns the best action with probability (1 - epsilon)
|
||||
and a random action with probability epsilon to ensure exploration.
|
||||
otherwise a random action with probability epsilon to ensure exploration.
|
||||
"""
|
||||
# with probability epsilon return a random action to explore the environment
|
||||
if np.random.random() < self.epsilon:
|
||||
action = env.action_space.sample()
|
||||
return env.action_space.sample()
|
||||
|
||||
# with probability (1 - epsilon) act greedily (exploit)
|
||||
else:
|
||||
action = np.argmax(self.q_values[state])
|
||||
return action
|
||||
return int(np.argmax(self.q_values[obs]))
|
||||
|
||||
def update(self, state, action, reward, next_state, done):
|
||||
"""
|
||||
Updates the Q-value of an action.
|
||||
"""
|
||||
old_q_value = self.q_values[state][action]
|
||||
max_future_q = np.max(self.q_values[next_state])
|
||||
target = reward + self.lr * max_future_q * (1 - done)
|
||||
self.q_values[state][action] = (1 - self.lr) * old_q_value + self.lr * target
|
||||
def update(
|
||||
self,
|
||||
obs: tuple[int, int, bool],
|
||||
action: int,
|
||||
reward: float,
|
||||
terminated: bool,
|
||||
next_obs: tuple[int, int, bool],
|
||||
):
|
||||
"""Updates the Q-value of an action."""
|
||||
future_q_value = (not terminated) * np.max(self.q_values[next_obs])
|
||||
temporal_difference = (
|
||||
reward + self.discount_factor * future_q_value - self.q_values[obs][action]
|
||||
)
|
||||
|
||||
self.q_values[obs][action] = (
|
||||
self.q_values[obs][action] + self.lr * temporal_difference
|
||||
)
|
||||
self.training_error.append(temporal_difference)
|
||||
|
||||
def decay_epsilon(self):
|
||||
self.epsilon = self.epsilon - epsilon_decay
|
||||
self.epsilon = max(self.final_epsilon, self.epsilon - epsilon_decay)
|
||||
|
||||
|
||||
# %%
|
||||
@@ -216,89 +243,144 @@ class BlackjackAgent:
|
||||
#
|
||||
|
||||
# hyperparameters
|
||||
learning_rate = 1e-3
|
||||
start_epsilon = 0.8
|
||||
n_episodes = 200_000
|
||||
epsilon_decay = start_epsilon / n_episodes # less exploration over time
|
||||
learning_rate = 0.01
|
||||
n_episodes = 100_000
|
||||
start_epsilon = 1.0
|
||||
epsilon_decay = start_epsilon / (n_episodes / 2) # reduce the exploration over time
|
||||
final_epsilon = 0.1
|
||||
|
||||
agent = BlackjackAgent(
|
||||
lr=learning_rate, epsilon=start_epsilon, epsilon_decay=epsilon_decay
|
||||
learning_rate=learning_rate,
|
||||
initial_epsilon=start_epsilon,
|
||||
epsilon_decay=epsilon_decay,
|
||||
final_epsilon=final_epsilon,
|
||||
)
|
||||
|
||||
|
||||
def train(agent, n_episodes):
|
||||
for episode in range(n_episodes):
|
||||
|
||||
# reset the environment
|
||||
state, info = env.reset()
|
||||
done = False
|
||||
|
||||
# play one episode
|
||||
while not done:
|
||||
action = agent.get_action(observation)
|
||||
next_state, reward, terminated, truncated, info = env.step(action)
|
||||
done = (
|
||||
terminated or truncated
|
||||
) # if the episode terminated or was truncated early, set done to True
|
||||
agent.update(state, action, reward, next_state, done)
|
||||
state = next_state
|
||||
|
||||
agent.update(state, action, reward, next_state, done)
|
||||
|
||||
|
||||
# %%
|
||||
# Great, let’s train!
|
||||
#
|
||||
|
||||
train(agent, n_episodes)
|
||||
|
||||
|
||||
# %%
|
||||
# Visualizing the results
|
||||
# ------------------------------
|
||||
# Info: The current hyperparameters are set to quickly train a decent agent.
|
||||
# If you want to converge to the optimal policy, try increasing
|
||||
# the n_episodes by 10x and lower the learning_rate (e.g. to 0.001).
|
||||
#
|
||||
|
||||
|
||||
def create_grids(agent, usable_ace=False):
|
||||
env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
|
||||
for episode in tqdm(range(n_episodes)):
|
||||
obs, info = env.reset()
|
||||
done = False
|
||||
|
||||
# play one episode
|
||||
while not done:
|
||||
action = agent.get_action(obs)
|
||||
next_obs, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
# update the agent
|
||||
agent.update(obs, action, reward, terminated, next_obs)
|
||||
|
||||
# update if the environment is done and the current obs
|
||||
done = terminated or truncated
|
||||
obs = next_obs
|
||||
|
||||
agent.decay_epsilon()
|
||||
|
||||
|
||||
# %%
|
||||
# Visualizing the training
|
||||
# ------------------------------
|
||||
#
|
||||
|
||||
rolling_length = 500
|
||||
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))
|
||||
axs[0].set_title("Episode rewards")
|
||||
reward_moving_average = (
|
||||
np.convolve(
|
||||
np.array(env.return_queue).flatten(), np.ones(rolling_length), mode="valid"
|
||||
)
|
||||
/ rolling_length
|
||||
)
|
||||
axs[0].plot(range(len(reward_moving_average)), reward_moving_average)
|
||||
axs[1].set_title("Episode lengths")
|
||||
length_moving_average = (
|
||||
np.convolve(
|
||||
np.array(env.length_queue).flatten(), np.ones(rolling_length), mode="same"
|
||||
)
|
||||
/ rolling_length
|
||||
)
|
||||
axs[1].plot(range(len(length_moving_average)), length_moving_average)
|
||||
axs[2].set_title("Training Error")
|
||||
training_error_moving_average = (
|
||||
np.convolve(np.array(agent.training_error), np.ones(rolling_length), mode="same")
|
||||
/ rolling_length
|
||||
)
|
||||
axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# %%
|
||||
# .. image:: /_static/img/tutorials/blackjack_training_plots.png
|
||||
#
|
||||
|
||||
|
||||
# %%
|
||||
# Visualising the policy
|
||||
# ------------------------------
|
||||
|
||||
|
||||
def create_grids(agent, usable_ace=False):
|
||||
"""Create value and policy grid given an agent."""
|
||||
# convert our state-action values to state values
|
||||
# and build a policy dictionary that maps observations to actions
|
||||
V = defaultdict(float)
|
||||
state_value = defaultdict(float)
|
||||
policy = defaultdict(int)
|
||||
for obs, action_values in agent.q_values.items():
|
||||
V[obs] = np.max(action_values)
|
||||
policy[obs] = np.argmax(action_values)
|
||||
state_value[obs] = float(np.max(action_values))
|
||||
policy[obs] = int(np.argmax(action_values))
|
||||
|
||||
X, Y = np.meshgrid(
|
||||
np.arange(12, 22), np.arange(1, 11) # players count
|
||||
) # dealers face-up card
|
||||
player_count, dealer_count = np.meshgrid(
|
||||
# players count, dealers face-up card
|
||||
np.arange(12, 22),
|
||||
np.arange(1, 11),
|
||||
)
|
||||
|
||||
# create the value grid for plotting
|
||||
Z = np.apply_along_axis(
|
||||
lambda obs: V[(obs[0], obs[1], usable_ace)], axis=2, arr=np.dstack([X, Y])
|
||||
value = np.apply_along_axis(
|
||||
lambda obs: state_value[(obs[0], obs[1], usable_ace)],
|
||||
axis=2,
|
||||
arr=np.dstack([player_count, dealer_count]),
|
||||
)
|
||||
value_grid = X, Y, Z
|
||||
value_grid = player_count, dealer_count, value
|
||||
|
||||
# create the policy grid for plotting
|
||||
policy_grid = np.apply_along_axis(
|
||||
lambda obs: policy[(obs[0], obs[1], usable_ace)], axis=2, arr=np.dstack([X, Y])
|
||||
lambda obs: policy[(obs[0], obs[1], usable_ace)],
|
||||
axis=2,
|
||||
arr=np.dstack([player_count, dealer_count]),
|
||||
)
|
||||
return value_grid, policy_grid
|
||||
|
||||
|
||||
def create_plots(value_grid, policy_grid, title="N/A"):
|
||||
|
||||
def create_plots(value_grid, policy_grid, title: str):
|
||||
"""Creates a plot using a value and policy grid."""
|
||||
# create a new figure with 2 subplots (left: state values, right: policy)
|
||||
X, Y, Z = value_grid
|
||||
player_count, dealer_count, value = value_grid
|
||||
fig = plt.figure(figsize=plt.figaspect(0.4))
|
||||
fig.suptitle(title, fontsize=16)
|
||||
|
||||
# plot the state values
|
||||
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
|
||||
ax1.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap="viridis", edgecolor="none")
|
||||
ax1.plot_surface(
|
||||
player_count,
|
||||
dealer_count,
|
||||
value,
|
||||
rstride=1,
|
||||
cstride=1,
|
||||
cmap="viridis",
|
||||
edgecolor="none",
|
||||
)
|
||||
plt.xticks(range(12, 22), range(12, 22))
|
||||
plt.yticks(range(1, 11), ["A"] + list(range(2, 11)))
|
||||
ax1.set_title("State values: " + title)
|
||||
ax1.set_title(f"State values: {title}")
|
||||
ax1.set_xlabel("Player sum")
|
||||
ax1.set_ylabel("Dealer showing")
|
||||
ax1.zaxis.set_rotate_label(False)
|
||||
@@ -308,7 +390,7 @@ def create_plots(value_grid, policy_grid, title="N/A"):
|
||||
# plot the policy
|
||||
fig.add_subplot(1, 2, 2)
|
||||
ax2 = sns.heatmap(policy_grid, linewidth=0, annot=True, cmap="Accent_r", cbar=False)
|
||||
ax2.set_title("Policy: " + title)
|
||||
ax2.set_title(f"Policy: {title}")
|
||||
ax2.set_xlabel("Player sum")
|
||||
ax2.set_ylabel("Dealer showing")
|
||||
ax2.set_xticklabels(range(12, 22))
|
||||
@@ -328,7 +410,6 @@ value_grid, policy_grid = create_grids(agent, usable_ace=True)
|
||||
fig1 = create_plots(value_grid, policy_grid, title="With usable ace")
|
||||
plt.show()
|
||||
|
||||
|
||||
# %%
|
||||
# .. image:: /_static/img/tutorials/blackjack_with_usable_ace.png
|
||||
#
|
||||
@@ -338,7 +419,6 @@ value_grid, policy_grid = create_grids(agent, usable_ace=False)
|
||||
fig2 = create_plots(value_grid, policy_grid, title="Without usable ace")
|
||||
plt.show()
|
||||
|
||||
|
||||
# %%
|
||||
# .. image:: /_static/img/tutorials/blackjack_without_usable_ace.png
|
||||
#
|
||||
@@ -346,7 +426,12 @@ plt.show()
|
||||
# so that any used resources by the environment will be closed.
|
||||
#
|
||||
|
||||
env.close()
|
||||
# %%
|
||||
# Think you can do better?
|
||||
# ------------------------------
|
||||
|
||||
# You can visualize the environment using the play function
|
||||
# and try to win a few games.
|
||||
|
||||
|
||||
# %%
|
||||
@@ -357,7 +442,7 @@ env.close()
|
||||
# It is recommended that you solve this environment by yourself (project
|
||||
# based learning is really effective!). You can apply your favorite
|
||||
# discrete RL algorithm or give Monte Carlo ES a try (covered in `Sutton &
|
||||
# Barto <http://incompleteideas.net/book/the-book-2nd.html>`__, section
|
||||
# Barto <http://incompleteideas.net/book/the-book-2nd.html>`_, section
|
||||
# 5.3) - this way you can compare your results directly to the book.
|
||||
#
|
||||
# Best of fun!
|
||||
|
Reference in New Issue
Block a user