Incorporated a few spelling and logical fixes and added training logs (#109)

Till Zemann
2022-11-03 11:46:29 +01:00
committed by GitHub
parent f85b4851f7
commit 88067cafbd
4 changed files with 196 additions and 111 deletions

Binary image file added (not shown), 70 KiB.

Binary image file changed (not shown), 122 KiB before, 102 KiB after.

Binary image file changed (not shown), 126 KiB before, 107 KiB after.


@@ -4,7 +4,6 @@ Solving Blackjack with Q-Learning
""" """
# %% # %%
# .. image:: /_static/img/tutorials/blackjack_AE_loop.jpg # .. image:: /_static/img/tutorials/blackjack_AE_loop.jpg
# :width: 650 # :width: 650
@@ -17,10 +16,15 @@ Solving Blackjack with Q-Learning
# infamous for being beatable under certain conditions. This version of
# the game uses an infinite deck (we draw the cards with replacement), so
# counting cards won't be a viable strategy in our simulated game.
+# Full documentation can be found at https://gymnasium.farama.org/environments/toy_text/blackjack
#
-# **Objective**: To win, your card sum should be greater than than the
+# **Objective**: To win, your card sum should be greater than the
# dealer's without exceeding 21.
#
+# **Actions**: Agents can pick between two actions:
+#  - stand (0): the player takes no more cards
+#  - hit (1): the player will be given another card, however the player could get over 21 and bust
+#
# **Approach**: To solve this environment by yourself, you can pick your
# favorite discrete RL algorithm. The presented solution uses *Q-learning*
# (a model-free RL algorithm).
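As a quick check of the action and observation structure described above, here is a small probe of the environment's spaces (a sketch; ``probe_env`` is a throwaway name, not part of the tutorial code):

import gymnasium as gym

probe_env = gym.make("Blackjack-v1", sab=True)
print(probe_env.action_space)       # Discrete(2): 0 = stand, 1 = hit
print(probe_env.observation_space)  # Tuple(Discrete(32), Discrete(11), Discrete(2))
obs, info = probe_env.reset()
print(obs)                          # e.g. (14, 10, False): player sum, dealer card, usable ace
probe_env.close()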
@@ -35,13 +39,17 @@ Solving Blackjack with Q-Learning
# Author: Till Zemann
# License: MIT License

+from __future__ import annotations
from collections import defaultdict

-import gym
+import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
-from matplotlib import pyplot as plt
from matplotlib.patches import Patch
+from tqdm import tqdm
+import gymnasium as gym

# Let's start by creating the blackjack environment.
# Note: We are going to follow the rules from Sutton & Barto.
@@ -49,15 +57,16 @@ from matplotlib.patches import Patch
env = gym.make("Blackjack-v1", sab=True)

# %%
# .. code:: py
#
-#   # Other possible environment configurations:
+#   # Other possible environment configurations are:
#
-#   env = gym.make('Blackjack-v1', natural=True, sab=False)``
+#   env = gym.make('Blackjack-v1', natural=True, sab=False)
+#   # Whether to give an additional reward for starting with a natural blackjack, i.e. starting with an ace and ten (sum is 21).
#
-#   env = gym.make('Blackjack-v1', natural=False, sab=False)``
+#   env = gym.make('Blackjack-v1', natural=False, sab=False)
+#   # Whether to follow the exact rules outlined in the book by Sutton and Barto. If `sab` is `True`, the keyword argument `natural` will be ignored.
#
@@ -68,21 +77,18 @@ env = gym.make("Blackjack-v1", sab=True)
# First of all, we call ``env.reset()`` to start an episode. This function
# resets the environment to a starting position and returns an initial
# ``observation``. We usually also set ``done = False``. This variable
-# will be useful later to check if a game is terminated. In this tutorial
-# we will use the terms observation and state synonymously but in more
-# complex problems a state might differ from the observation it is based
-# on.
+# will be useful later to check if a game is terminated (i.e., the player wins or loses).
#

# reset the environment to get the first observation
done = False
observation, info = env.reset()
-print(observation)
+# observation = (16, 9, False)

# %%
-# Note that our observation is a 3-tuple consisting of 3 discrete values:
+# Note that our observation is a 3-tuple consisting of 3 values:
#
# - The player's current sum
# - Value of the dealer's face-up card
@@ -106,13 +112,13 @@ print(observation)
# - ``reward``: This is the reward that the agent will receive after
#   taking the action.
# - ``terminated``: This is a boolean variable that indicates whether or
-#   not the episode is over.
+#   not the environment has terminated.
# - ``truncated``: This is a boolean variable that also indicates whether
-#   the episode ended by early truncation.
+#   the episode ended by early truncation, i.e., a time limit is reached.
# - ``info``: This is a dictionary that might contain additional
#   information about the environment.
#
-# The ``next_state``, ``reward``, and ``done`` variables are
+# The ``next_state``, ``reward``, ``terminated`` and ``truncated`` variables are
# self-explanatory, but the ``info`` variable requires some additional
# explanation. This variable contains a dictionary that might have some
# extra information about the environment, but in the Blackjack-v1
@@ -120,34 +126,29 @@ print(observation)
# info dictionary has a ``ale.lives`` key that tells us how many lives the
# agent has left. If the agent has 0 lives, then the episode is over.
#
-# Blackjack-v1 doesn't have a ``env.render()`` function to render the
-# environment, but in other environments you can use this function to
-# watch the agent play. Important to note is that using ``env.render()``
-# is optional - the environment is going to work even if you don't render
-# it, but it can be helpful to see an episode rendered out to get an idea
-# of how the current policy behaves. Note that it is not a good idea to
-# call this function in your training loop because rendering slows down
-# training by a lot. Rather try to build an extra loop to evaluate and
-# showcase the agent after training.
+# Note that it is not a good idea to call ``env.render()`` in your training
+# loop because rendering slows down training by a lot. Rather try to build
+# an extra loop to evaluate and showcase the agent after training.
#

# sample a random action from all valid actions
action = env.action_space.sample()
+# action=1

# execute the action in our environment and receive infos from the environment
observation, reward, terminated, truncated, info = env.step(action)
-print("observation:", observation)
-print("reward:", reward)
-print("terminated:", terminated)
-print("truncated:", truncated)
-print("info:", info)
+# observation=(24, 10, False)
+# reward=-1.0
+# terminated=True
+# truncated=False
+# info={}

# %%
# Once ``terminated = True`` or ``truncated=True``, we should stop the
# current episode and begin a new one with ``env.reset()``. If you
-# continue executing act`ons without resetting the environment, it still
+# continue executing actions without resetting the environment, it still
# responds but the output won't be useful for training (it might even be
# harmful if the agent learns on invalid data).
#
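Putting the pieces above together, a minimal random-policy rollout looks like this (a sketch; it reuses the ``env`` created earlier and resets whenever the episode terminates or is truncated):

for _ in range(3):  # play three random episodes
    obs, info = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
    print("episode finished with reward", reward)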
@@ -159,7 +160,7 @@ print("info:", info)
#
# Let's build a ``Q-learning agent`` to solve *Blackjack-v1*! We'll need
# some functions for picking an action and updating the agent's action
-# values. To ensure that the agents expores the environment, one possible
+# values. To ensure that the agents explores the environment, one possible
# solution is the ``epsilon-greedy`` strategy, where we pick a random
# action with the percentage ``epsilon`` and the greedy action (currently
# valued as the best) ``1 - epsilon``.
@@ -167,43 +168,69 @@ print("info:", info)
class BlackjackAgent:
-    def __init__(self, lr=1e-3, epsilon=0.1, epsilon_decay=1e-4):
-        """
-        Initialize an Reinforcement Learning agent with an empty dictionary
+    def __init__(
+        self,
+        learning_rate: float,
+        initial_epsilon: float,
+        epsilon_decay: float,
+        final_epsilon: float,
+        discount_factor: float = 0.95,
+    ):
+        """Initialize a Reinforcement Learning agent with an empty dictionary
        of state-action values (q_values), a learning rate and an epsilon.
-        """
-        self.q_values = defaultdict(
-            lambda: np.zeros(env.action_space.n)
-        )  # maps a state to action values
-        self.lr = lr
-        self.epsilon = epsilon
-        self.epsilon_decay = epsilon_decay
+
+        Args:
+            learning_rate: The learning rate
+            initial_epsilon: The initial epsilon value
+            epsilon_decay: The decay for epsilon
+            final_epsilon: The final epsilon value
+            discount_factor: The discount factor for computing the Q-value
+        """
+        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))
+        self.lr = learning_rate
+        self.discount_factor = discount_factor
+        self.epsilon = initial_epsilon
+        self.epsilon_decay = epsilon_decay
+        self.final_epsilon = final_epsilon
+        self.training_error = []

-    def get_action(self, state):
+    def get_action(self, obs: tuple[int, int, bool]) -> int:
        """
        Returns the best action with probability (1 - epsilon)
-        and a random action with probability epsilon to ensure exploration.
+        otherwise a random action with probability epsilon to ensure exploration.
        """
        # with probability epsilon return a random action to explore the environment
        if np.random.random() < self.epsilon:
-            action = env.action_space.sample()
+            return env.action_space.sample()
        # with probability (1 - epsilon) act greedily (exploit)
        else:
-            action = np.argmax(self.q_values[state])
-        return action
+            return int(np.argmax(self.q_values[obs]))

-    def update(self, state, action, reward, next_state, done):
-        """
-        Updates the Q-value of an action.
-        """
-        old_q_value = self.q_values[state][action]
-        max_future_q = np.max(self.q_values[next_state])
-        target = reward + self.lr * max_future_q * (1 - done)
-        self.q_values[state][action] = (1 - self.lr) * old_q_value + self.lr * target
+    def update(
+        self,
+        obs: tuple[int, int, bool],
+        action: int,
+        reward: float,
+        terminated: bool,
+        next_obs: tuple[int, int, bool],
+    ):
+        """Updates the Q-value of an action."""
+        future_q_value = (not terminated) * np.max(self.q_values[next_obs])
+        temporal_difference = (
+            reward + self.discount_factor * future_q_value - self.q_values[obs][action]
+        )
+        self.q_values[obs][action] = (
+            self.q_values[obs][action] + self.lr * temporal_difference
+        )
+        self.training_error.append(temporal_difference)

    def decay_epsilon(self):
-        self.epsilon = self.epsilon - epsilon_decay
+        self.epsilon = max(self.final_epsilon, self.epsilon - epsilon_decay)

# %%
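To see the updated TD rule in concrete numbers, here is a tiny hand-worked example (a sketch; ``demo`` is an illustrative instance and assumes the ``env`` and the new ``BlackjackAgent`` shown above):

demo = BlackjackAgent(
    learning_rate=0.5,
    initial_epsilon=0.0,
    epsilon_decay=0.0,
    final_epsilon=0.0,
    discount_factor=0.9,
)
obs, next_obs = (13, 10, False), (20, 10, False)
demo.q_values[next_obs][:] = [0.0, 0.4]  # pretend we already learned something about next_obs
demo.update(obs=obs, action=1, reward=0.0, terminated=False, next_obs=next_obs)
# TD target = 0.0 + 0.9 * max(0.0, 0.4) = 0.36, so Q(obs, hit) moves from 0.0 to 0.5 * 0.36 = 0.18
print(demo.q_values[obs])  # [0.   0.18]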
@@ -216,89 +243,144 @@ class BlackjackAgent:
#
# hyperparameters
-learning_rate = 1e-3
-start_epsilon = 0.8
-n_episodes = 200_000
-epsilon_decay = start_epsilon / n_episodes  # less exploration over time
+learning_rate = 0.01
+n_episodes = 100_000
+start_epsilon = 1.0
+epsilon_decay = start_epsilon / (n_episodes / 2)  # reduce the exploration over time
+final_epsilon = 0.1

agent = BlackjackAgent(
-    lr=learning_rate, epsilon=start_epsilon, epsilon_decay=epsilon_decay
+    learning_rate=learning_rate,
+    initial_epsilon=start_epsilon,
+    epsilon_decay=epsilon_decay,
+    final_epsilon=final_epsilon,
)
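For intuition about the new decay schedule, a short sketch (reusing the hyperparameter names above) of how epsilon falls linearly and then sits at ``final_epsilon`` from roughly 45% of training onwards:

import numpy as np

episodes = np.arange(n_episodes)
epsilons = np.maximum(final_epsilon, start_epsilon - episodes * epsilon_decay)
print(epsilons[0], epsilons[25_000], epsilons[50_000], epsilons[-1])  # 1.0 0.5 0.1 0.1 (up to float rounding)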
-def train(agent, n_episodes):
-    for episode in range(n_episodes):
-        # reset the environment
-        state, info = env.reset()
-        done = False
-        # play one episode
-        while not done:
-            action = agent.get_action(observation)
-            next_state, reward, terminated, truncated, info = env.step(action)
-            done = (
-                terminated or truncated
-            )  # if the episode terminated or was truncated early, set done to True
-            agent.update(state, action, reward, next_state, done)
-            state = next_state
-        agent.update(state, action, reward, next_state, done)

# %%
# Great, let's train!
#
-train(agent, n_episodes)
+# Info: The current hyperparameters are set to quickly train a decent agent.
+# If you want to converge to the optimal policy, try increasing
+# the n_episodes by 10x and lower the learning_rate (e.g. to 0.001).
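Taking that note literally, a longer run could be configured like this before training (a sketch using the values from the comment; the agent has to be re-created so it picks up the new learning rate):

learning_rate = 0.001
n_episodes = 1_000_000
epsilon_decay = start_epsilon / (n_episodes / 2)

agent = BlackjackAgent(
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
)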
-# %%
-# Visualizing the results
-# ------------------------------
-#
-def create_grids(agent, usable_ace=False):
+env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
+for episode in tqdm(range(n_episodes)):
+    obs, info = env.reset()
+    done = False
+    # play one episode
+    while not done:
+        action = agent.get_action(obs)
+        next_obs, reward, terminated, truncated, info = env.step(action)
+        # update the agent
+        agent.update(obs, action, reward, terminated, next_obs)
+        # update if the environment is done and the current obs
+        done = terminated or truncated
+        obs = next_obs
+    agent.decay_epsilon()

+# %%
+# Visualizing the training
+# ------------------------------
+#
+rolling_length = 500
+fig, axs = plt.subplots(ncols=3, figsize=(12, 5))
+axs[0].set_title("Episode rewards")
+reward_moving_average = (
+    np.convolve(
+        np.array(env.return_queue).flatten(), np.ones(rolling_length), mode="valid"
+    )
+    / rolling_length
+)
+axs[0].plot(range(len(reward_moving_average)), reward_moving_average)
+axs[1].set_title("Episode lengths")
+length_moving_average = (
+    np.convolve(
+        np.array(env.length_queue).flatten(), np.ones(rolling_length), mode="same"
+    )
+    / rolling_length
+)
+axs[1].plot(range(len(length_moving_average)), length_moving_average)
+axs[2].set_title("Training Error")
+training_error_moving_average = (
+    np.convolve(np.array(agent.training_error), np.ones(rolling_length), mode="same")
+    / rolling_length
+)
+axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
+plt.tight_layout()
+plt.show()

+# %%
+# .. image:: /_static/img/tutorials/blackjack_training_plots.png
+#

+# %%
+# Visualising the policy
+# ------------------------------
+def create_grids(agent, usable_ace=False):
+    """Create value and policy grid given an agent."""
    # convert our state-action values to state values
    # and build a policy dictionary that maps observations to actions
-    V = defaultdict(float)
+    state_value = defaultdict(float)
    policy = defaultdict(int)
    for obs, action_values in agent.q_values.items():
-        V[obs] = np.max(action_values)
-        policy[obs] = np.argmax(action_values)
+        state_value[obs] = float(np.max(action_values))
+        policy[obs] = int(np.argmax(action_values))

-    X, Y = np.meshgrid(
-        np.arange(12, 22), np.arange(1, 11)  # players count
-    )  # dealers face-up card
+    player_count, dealer_count = np.meshgrid(
+        # players count, dealers face-up card
+        np.arange(12, 22),
+        np.arange(1, 11),
+    )

    # create the value grid for plotting
-    Z = np.apply_along_axis(
-        lambda obs: V[(obs[0], obs[1], usable_ace)], axis=2, arr=np.dstack([X, Y])
+    value = np.apply_along_axis(
+        lambda obs: state_value[(obs[0], obs[1], usable_ace)],
+        axis=2,
+        arr=np.dstack([player_count, dealer_count]),
    )
-    value_grid = X, Y, Z
+    value_grid = player_count, dealer_count, value

    # create the policy grid for plotting
    policy_grid = np.apply_along_axis(
-        lambda obs: policy[(obs[0], obs[1], usable_ace)], axis=2, arr=np.dstack([X, Y])
+        lambda obs: policy[(obs[0], obs[1], usable_ace)],
+        axis=2,
+        arr=np.dstack([player_count, dealer_count]),
    )
    return value_grid, policy_grid

-def create_plots(value_grid, policy_grid, title="N/A"):
+def create_plots(value_grid, policy_grid, title: str):
+    """Creates a plot using a value and policy grid."""
    # create a new figure with 2 subplots (left: state values, right: policy)
-    X, Y, Z = value_grid
+    player_count, dealer_count, value = value_grid
    fig = plt.figure(figsize=plt.figaspect(0.4))
    fig.suptitle(title, fontsize=16)

    # plot the state values
    ax1 = fig.add_subplot(1, 2, 1, projection="3d")
-    ax1.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap="viridis", edgecolor="none")
+    ax1.plot_surface(
+        player_count,
+        dealer_count,
+        value,
+        rstride=1,
+        cstride=1,
+        cmap="viridis",
+        edgecolor="none",
+    )
    plt.xticks(range(12, 22), range(12, 22))
    plt.yticks(range(1, 11), ["A"] + list(range(2, 11)))
-    ax1.set_title("State values: " + title)
+    ax1.set_title(f"State values: {title}")
    ax1.set_xlabel("Player sum")
    ax1.set_ylabel("Dealer showing")
    ax1.zaxis.set_rotate_label(False)
@@ -308,7 +390,7 @@ def create_plots(value_grid, policy_grid, title="N/A"):
    # plot the policy
    fig.add_subplot(1, 2, 2)
    ax2 = sns.heatmap(policy_grid, linewidth=0, annot=True, cmap="Accent_r", cbar=False)
-    ax2.set_title("Policy: " + title)
+    ax2.set_title(f"Policy: {title}")
    ax2.set_xlabel("Player sum")
    ax2.set_ylabel("Dealer showing")
    ax2.set_xticklabels(range(12, 22))
@@ -328,7 +410,6 @@ value_grid, policy_grid = create_grids(agent, usable_ace=True)
fig1 = create_plots(value_grid, policy_grid, title="With usable ace")
plt.show()

# %%
# .. image:: /_static/img/tutorials/blackjack_with_usable_ace.png
#
@@ -338,7 +419,6 @@ value_grid, policy_grid = create_grids(agent, usable_ace=False)
fig2 = create_plots(value_grid, policy_grid, title="Without usable ace")
plt.show()

# %%
# .. image:: /_static/img/tutorials/blackjack_without_usable_ace.png
#
@@ -346,7 +426,12 @@ plt.show()
# so that any used resources by the environment will be closed.
#
-env.close()
+# %%
+# Think you can do better?
+# ------------------------------
+# You can visualize the environment using the play function
+# and try to win a few games.

# %%
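The ``play`` function mentioned above presumably refers to the interactive ``gymnasium.utils.play`` helper; an alternative that needs no keyboard mapping is a small evaluation loop with human rendering (a sketch; it reuses the trained ``agent`` and assumes Blackjack-v1 supports ``render_mode="human"``):

eval_env = gym.make("Blackjack-v1", sab=True, render_mode="human")
for _ in range(5):
    obs, info = eval_env.reset()
    done = False
    while not done:
        action = int(np.argmax(agent.q_values[obs]))  # act greedily, no exploration
        obs, reward, terminated, truncated, info = eval_env.step(action)
        done = terminated or truncated
    print("reward:", reward)
eval_env.close()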
@@ -357,7 +442,7 @@ env.close()
# It is recommended that you solve this environment by yourself (project
# based learning is really effective!). You can apply your favorite
# discrete RL algorithm or give Monte Carlo ES a try (covered in `Sutton &
-# Barto <http://incompleteideas.net/book/the-book-2nd.html>`__, section
+# Barto <http://incompleteideas.net/book/the-book-2nd.html>`_, section
# 5.3) - this way you can compare your results directly to the book.
#
# Best of fun!
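For the suggested follow-up, here is a compact sketch of on-policy first-visit Monte Carlo control with an epsilon-soft policy. It stands in for Monte Carlo ES, since Blackjack-v1 does not let us force exploring starts; all names (``mc_env``, ``mc_q_values``, ``mc_policy_action``) are illustrative and not part of the tutorial:

from collections import defaultdict

import numpy as np
import gymnasium as gym

mc_env = gym.make("Blackjack-v1", sab=True)
mc_q_values = defaultdict(lambda: np.zeros(mc_env.action_space.n))
returns_count = defaultdict(lambda: np.zeros(mc_env.action_space.n))
mc_epsilon, mc_gamma = 0.1, 1.0  # undiscounted, like the Blackjack examples in the book

def mc_policy_action(obs):
    """Epsilon-soft policy derived from the current action-value estimates."""
    if np.random.random() < mc_epsilon:
        return mc_env.action_space.sample()
    return int(np.argmax(mc_q_values[obs]))

for _ in range(100_000):
    # generate one episode with the current epsilon-soft policy
    episode = []
    obs, info = mc_env.reset()
    done = False
    while not done:
        action = mc_policy_action(obs)
        next_obs, reward, terminated, truncated, info = mc_env.step(action)
        episode.append((obs, action, reward))
        obs = next_obs
        done = terminated or truncated

    # first-visit Monte Carlo update, walking the episode backwards
    g = 0.0
    for t in reversed(range(len(episode))):
        obs_t, action_t, reward_t = episode[t]
        g = mc_gamma * g + reward_t
        if (obs_t, action_t) not in [(o, a) for o, a, _ in episode[:t]]:
            returns_count[obs_t][action_t] += 1
            mc_q_values[obs_t][action_t] += (
                g - mc_q_values[obs_t][action_t]
            ) / returns_count[obs_t][action_t]

mc_env.close()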