Incorporated a few spelling and logical fixes and added training logs (#109)

This commit is contained in:
Till Zemann
2022-11-03 11:46:29 +01:00
committed by GitHub
parent f85b4851f7
commit 88067cafbd
4 changed files with 196 additions and 111 deletions

Binary image added (70 KiB).

Binary image updated (122 KiB → 102 KiB).

Binary image updated (126 KiB → 107 KiB).

@@ -4,7 +4,6 @@ Solving Blackjack with Q-Learning
"""
# %%
# .. image:: /_static/img/tutorials/blackjack_AE_loop.jpg
# :width: 650
@@ -17,10 +16,15 @@ Solving Blackjack with Q-Learning
# infamous for being beatable under certain conditions. This version of
# the game uses an infinite deck (we draw the cards with replacement), so
# counting cards won't be a viable strategy in our simulated game.
# Full documentation can be found at https://gymnasium.farama.org/environments/toy_text/blackjack
#
# **Objective**: To win, your card sum should be greater than the
# dealer's without exceeding 21.
#
# **Actions**: Agents can pick between two actions:
# - stand (0): the player takes no more cards
# - hit (1): the player will be given another card; however, the player could go over 21 and bust
#
# **Approach**: To solve this environment by yourself, you can pick your
# favorite discrete RL algorithm. The presented solution uses *Q-learning*
# (a model-free RL algorithm).
@@ -35,13 +39,17 @@ Solving Blackjack with Q-Learning
# Author: Till Zemann
# License: MIT License
from __future__ import annotations

from collections import defaultdict

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.patches import Patch
from tqdm import tqdm

import gymnasium as gym
# Let's start by creating the blackjack environment.
# Note: We are going to follow the rules from Sutton & Barto.
@@ -49,15 +57,16 @@ from matplotlib.patches import Patch
env = gym.make("Blackjack-v1", sab=True)
# %%
# .. code:: py
#
# # Other possible environment configurations are:
#
# env = gym.make('Blackjack-v1', natural=True, sab=False)
# # Whether to give an additional reward for starting with a natural blackjack, i.e. starting with an ace and ten (sum is 21).
#
# env = gym.make('Blackjack-v1', natural=False, sab=False)
# # Whether to follow the exact rules outlined in the book by Sutton and Barto. If `sab` is `True`, the keyword argument `natural` will be ignored.
#
@@ -68,21 +77,18 @@ env = gym.make("Blackjack-v1", sab=True)
# First of all, we call ``env.reset()`` to start an episode. This function
# resets the environment to a starting position and returns an initial
# ``observation``. We usually also set ``done = False``. This variable
# will be useful later to check if a game is terminated (i.e., the player
# wins or loses).
#
# reset the environment to get the first observation
done = False
observation, info = env.reset()
print(observation)
# observation = (16, 9, False)
# %%
# Note that our observation is a 3-tuple consisting of 3 values:
#
# - The player's current sum
# - Value of the dealer's face-up card
@@ -106,13 +112,13 @@ print(observation)
# - ``reward``: This is the reward that the agent will receive after
# taking the action.
# - ``terminated``: This is a boolean variable that indicates whether or
# not the environment has terminated.
# - ``truncated``: This is a boolean variable that also indicates whether
# the episode ended by early truncation, i.e., a time limit is reached.
# - ``info``: This is a dictionary that might contain additional
# information about the environment.
#
# The ``next_state``, ``reward``, ``terminated`` and ``truncated`` variables are
# self-explanatory, but the ``info`` variable requires some additional
# explanation. This variable contains a dictionary that might have some
# extra information about the environment, but in the Blackjack-v1
@@ -120,34 +126,29 @@ print(observation)
# info dictionary has a ``ale.lives`` key that tells us how many lives the
# agent has left. If the agent has 0 lives, then the episode is over.
#
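# In such environments, a check along these lines could be used (a hypothetical
# sketch for Atari-style environments; Blackjack-v1 itself returns an empty
# ``info`` dictionary, so this does not apply here):
#
# .. code:: py
#
#    # hypothetical: reset once the agent has no lives left
#    if info.get("ale.lives", 1) == 0:
#        obs, info = env.reset()
#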
# Note that it is not a good idea to call ``env.render()`` in your training
# loop because rendering slows down training by a lot. Rather try to build
# an extra loop to evaluate and showcase the agent after training.
#
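# Such an evaluation loop might look like this (a sketch; it assumes a second
# environment created with ``render_mode="human"`` and a hypothetical trained
# ``policy`` function, neither of which is defined at this point):
#
# .. code:: py
#
#    eval_env = gym.make("Blackjack-v1", sab=True, render_mode="human")
#    for _ in range(5):
#        obs, info = eval_env.reset()
#        done = False
#        while not done:
#            action = policy(obs)  # hypothetical trained policy
#            obs, reward, terminated, truncated, info = eval_env.step(action)
#            done = terminated or truncated
#    eval_env.close()
#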
# sample a random action from all valid actions
action = env.action_space.sample()
# action=1
# execute the action in our environment and receive info from the environment
observation, reward, terminated, truncated, info = env.step(action)
print("observation:", observation)
print("reward:", reward)
print("terminated:", terminated)
print("truncated:", truncated)
print("info:", info)
# observation=(24, 10, False)
# reward=-1.0
# terminated=True
# truncated=False
# info={}
# %%
# Once ``terminated = True`` or ``truncated=True``, we should stop the
# current episode and begin a new one with ``env.reset()``. If you
# continue executing actions without resetting the environment, it still
# responds but the output won't be useful for training (it might even be
# harmful if the agent learns on invalid data).
#
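# Put together, a bare interaction loop could look like this (a sketch using
# random actions):
#
# .. code:: py
#
#    obs, info = env.reset()
#    for _ in range(1000):
#        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
#        # always reset once an episode has ended
#        if terminated or truncated:
#            obs, info = env.reset()
#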
@@ -159,7 +160,7 @@ print("info:", info)
#
# Let's build a ``Q-learning agent`` to solve *Blackjack-v1*! We'll need
# some functions for picking an action and updating the agent's action
# values. To ensure that the agent explores the environment, one possible
# solution is the ``epsilon-greedy`` strategy, where we pick a random
# action with probability ``epsilon`` and the greedy action (currently
# valued as the best) with probability ``1 - epsilon``.
@@ -167,43 +168,69 @@ print("info:", info)
class BlackjackAgent:
    def __init__(
        self,
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float = 0.95,
    ):
        """Initialize a Reinforcement Learning agent with an empty dictionary
        of state-action values (q_values), a learning rate and an epsilon.

        Args:
            learning_rate: The learning rate
            initial_epsilon: The initial epsilon value
            epsilon_decay: The decay for epsilon
            final_epsilon: The final epsilon value
            discount_factor: The discount factor for computing the Q-value
        """
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))
        self.lr = learning_rate
        self.discount_factor = discount_factor

        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon

        self.training_error = []

    def get_action(self, obs: tuple[int, int, bool]) -> int:
        """
        Returns the best action with probability (1 - epsilon),
        otherwise a random action with probability epsilon to ensure exploration.
        """
        # with probability epsilon return a random action to explore the environment
        if np.random.random() < self.epsilon:
            return env.action_space.sample()

        # with probability (1 - epsilon) act greedily (exploit)
        else:
            return int(np.argmax(self.q_values[obs]))

    def update(
        self,
        obs: tuple[int, int, bool],
        action: int,
        reward: float,
        terminated: bool,
        next_obs: tuple[int, int, bool],
    ):
        """Updates the Q-value of an action."""
        future_q_value = (not terminated) * np.max(self.q_values[next_obs])
        temporal_difference = (
            reward + self.discount_factor * future_q_value - self.q_values[obs][action]
        )

        self.q_values[obs][action] = (
            self.q_values[obs][action] + self.lr * temporal_difference
        )
        self.training_error.append(temporal_difference)

    def decay_epsilon(self):
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)
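
# %%
# For reference, ``update`` above implements the standard tabular Q-learning
# rule; the parenthesized term is the temporal difference that we store in
# ``training_error``, and ``(not terminated)`` zeroes the bootstrap term at
# terminal states:
#
# .. math::
#
#     Q(s, a) \leftarrow Q(s, a) + \alpha \cdot \bigl(r + \gamma \max_{a'} Q(s', a') - Q(s, a)\bigr)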
# %%
@@ -216,89 +243,144 @@ class BlackjackAgent:
#
# hyperparameters
learning_rate = 0.01
n_episodes = 100_000
start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2) # reduce the exploration over time
final_epsilon = 0.1
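# with this schedule, epsilon decays linearly by 1 / 50_000 per episode and
# reaches final_epsilon after (1.0 - 0.1) / (1.0 / 50_000) = 45_000 of the
# 100_000 episodes, then stays at 0.1 for the remainder of training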
agent = BlackjackAgent(
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
)
# %%
# Great, let's train!
#
# Info: The current hyperparameters are set to quickly train a decent agent.
# If you want to converge to the optimal policy, try increasing
# the n_episodes by 10x and lower the learning_rate (e.g. to 0.001).
#
env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
for episode in tqdm(range(n_episodes)):
    obs, info = env.reset()
    done = False

    # play one episode
    while not done:
        action = agent.get_action(obs)
        next_obs, reward, terminated, truncated, info = env.step(action)

        # update the agent
        agent.update(obs, action, reward, terminated, next_obs)

        # update whether the environment is done and the current obs
        done = terminated or truncated
        obs = next_obs

    agent.decay_epsilon()
# %%
# Visualizing the training
# ------------------------------
#
rolling_length = 500
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))

# plot a moving average of the episode rewards to smooth the curve
axs[0].set_title("Episode rewards")
reward_moving_average = (
    np.convolve(
        np.array(env.return_queue).flatten(), np.ones(rolling_length), mode="valid"
    )
    / rolling_length
)
axs[0].plot(range(len(reward_moving_average)), reward_moving_average)

axs[1].set_title("Episode lengths")
length_moving_average = (
    np.convolve(
        np.array(env.length_queue).flatten(), np.ones(rolling_length), mode="same"
    )
    / rolling_length
)
axs[1].plot(range(len(length_moving_average)), length_moving_average)

axs[2].set_title("Training Error")
training_error_moving_average = (
    np.convolve(np.array(agent.training_error), np.ones(rolling_length), mode="same")
    / rolling_length
)
axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)

plt.tight_layout()
plt.show()
# %%
# .. image:: /_static/img/tutorials/blackjack_training_plots.png
#
# %%
# Visualizing the policy
# ------------------------------
def create_grids(agent, usable_ace=False):
    """Create value and policy grid given an agent."""
    # convert our state-action values to state values
    # and build a policy dictionary that maps observations to actions
    state_value = defaultdict(float)
    policy = defaultdict(int)
    for obs, action_values in agent.q_values.items():
        state_value[obs] = float(np.max(action_values))
        policy[obs] = int(np.argmax(action_values))

    player_count, dealer_count = np.meshgrid(
        # player's count, dealer's face-up card
        np.arange(12, 22),
        np.arange(1, 11),
    )

    # create the value grid for plotting
    value = np.apply_along_axis(
        lambda obs: state_value[(obs[0], obs[1], usable_ace)],
        axis=2,
        arr=np.dstack([player_count, dealer_count]),
    )
    value_grid = player_count, dealer_count, value

    # create the policy grid for plotting
    policy_grid = np.apply_along_axis(
        lambda obs: policy[(obs[0], obs[1], usable_ace)],
        axis=2,
        arr=np.dstack([player_count, dealer_count]),
    )
    return value_grid, policy_grid
def create_plots(value_grid, policy_grid, title: str):
    """Creates a plot using a value and policy grid."""
    # create a new figure with 2 subplots (left: state values, right: policy)
    player_count, dealer_count, value = value_grid
    fig = plt.figure(figsize=plt.figaspect(0.4))
    fig.suptitle(title, fontsize=16)

    # plot the state values
    ax1 = fig.add_subplot(1, 2, 1, projection="3d")
    ax1.plot_surface(
        player_count,
        dealer_count,
        value,
        rstride=1,
        cstride=1,
        cmap="viridis",
        edgecolor="none",
    )
    plt.xticks(range(12, 22), range(12, 22))
    plt.yticks(range(1, 11), ["A"] + list(range(2, 11)))
    ax1.set_title(f"State values: {title}")
    ax1.set_xlabel("Player sum")
    ax1.set_ylabel("Dealer showing")
    ax1.zaxis.set_rotate_label(False)
@@ -308,7 +390,7 @@ def create_plots(value_grid, policy_grid, title="N/A"):
    # plot the policy
    fig.add_subplot(1, 2, 2)
    ax2 = sns.heatmap(policy_grid, linewidth=0, annot=True, cmap="Accent_r", cbar=False)
    ax2.set_title(f"Policy: {title}")
    ax2.set_xlabel("Player sum")
    ax2.set_ylabel("Dealer showing")
    ax2.set_xticklabels(range(12, 22))
@@ -328,7 +410,6 @@ value_grid, policy_grid = create_grids(agent, usable_ace=True)
fig1 = create_plots(value_grid, policy_grid, title="With usable ace")
plt.show()
# %%
# .. image:: /_static/img/tutorials/blackjack_with_usable_ace.png
#
@@ -338,7 +419,6 @@ value_grid, policy_grid = create_grids(agent, usable_ace=False)
fig2 = create_plots(value_grid, policy_grid, title="Without usable ace")
plt.show()
# %%
# .. image:: /_static/img/tutorials/blackjack_without_usable_ace.png
#
@@ -346,7 +426,12 @@ plt.show()
# so that any resources used by the environment will be closed.
#
env.close()
# %%
# Think you can do better?
# ------------------------------
# You can visualize the environment using the play function
# and try to win a few games.
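#
# A possible sketch (it assumes ``gymnasium.utils.play``; the key bindings
# below are an illustrative choice, not part of this tutorial):
#
# .. code:: py
#
#    from gymnasium.utils.play import play
#
#    # map keyboard keys to actions: "h" to hit (1), "s" to stand (0)
#    play(
#        gym.make("Blackjack-v1", render_mode="rgb_array"),
#        keys_to_action={"h": 1, "s": 0},
#    )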
# %%
@@ -357,7 +442,7 @@ env.close()
# It is recommended that you solve this environment by yourself
# (project-based learning is really effective!). You can apply your favorite
# discrete RL algorithm or give Monte Carlo ES a try (covered in `Sutton &
# Barto <http://incompleteideas.net/book/the-book-2nd.html>`_, section
# 5.3) - this way you can compare your results directly to the book.
#
# Best of luck!