2023-02-21 05:38:11 -06:00
""" This module provides a Blackjack functional environment and Gymnasium environment wrapper BlackJackJaxEnv. """
import math
import os
from typing import NamedTuple , Optional , Tuple , Union
import jax
import jax . numpy as jnp
import numpy as np
2023-12-17 15:03:06 +01:00
from flax import struct
2023-02-21 05:38:11 -06:00
from jax import random
from jax . random import PRNGKey
from gymnasium import spaces
2023-11-07 13:27:25 +00:00
from gymnasium . envs . functional_jax_env import FunctionalJaxEnv
2023-02-21 05:38:11 -06:00
from gymnasium . error import DependencyNotInstalled
2023-11-07 13:27:25 +00:00
from gymnasium . functional import ActType , FuncEnv , StateType
2023-02-21 05:38:11 -06:00
from gymnasium . utils import EzPickle , seeding
from gymnasium . wrappers import HumanRendering
# Render state passed between the render_* methods: the pygame surface being
# drawn on, the face-card rank ("J", "Q" or "K") displayed when the dealer's
# visible card is worth 10, and the suit letter ("C", "D", "H" or "S") of that
# card. Both trailing entries are strings (see render_init), so the alias says so.
RenderStateType = Tuple["pygame.Surface", str, str]  # type: ignore  # noqa: F821

# Card values that can be drawn: ace (1), 2-9, and four entries worth 10.
deck = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10])
class EnvState(NamedTuple):
    """Immutable record holding everything needed to describe a blackjack round."""

    # Card values held by the dealer; slot 0 is the card reported in observations.
    dealer_hand: jax.Array
    # Card values held by the player.
    player_hand: jax.Array
    # Number of cards dealt to the dealer so far (also the next free slot index).
    dealer_cards: int
    # Number of cards dealt to the player so far (also the next free slot index).
    player_cards: int
    # 0 while the round is in progress, greater than 0 once it has ended.
    done: int
def cmp(a, b):
    """Three-way comparison of ``a`` and ``b``.

    Returns 1 if ``a > b``, -1 if ``a < b`` and 0 if they are equal, as an
    integer JAX array. Accepts JAX arrays as well as plain Python scalars.
    """
    # jnp.greater/jnp.less always return arrays, so .astype is available even
    # when both inputs are Python numbers.
    return jnp.greater(a, b).astype(int) - jnp.less(a, b).astype(int)
def random_card(key):
    """Draw one random card value from ``deck`` (sampling with replacement).

    Returns a ``(card, key)`` pair where ``key`` is the derived PRNG key that
    was used for the draw.
    """
    # Derive a new key from the incoming one; only the first half of the split is kept.
    next_key, _ = random.split(key)
    # NOTE(review): the same derived key is both consumed by `choice` and handed
    # back to the caller for further splitting - confirm this reuse is intended.
    drawn = random.choice(next_key, deck, shape=(1,))
    card = drawn[0].astype(int)
    return card, next_key
def draw_hand(key, hand):
    """Deal a two-card starting hand into slots 0 and 1 of ``hand``.

    Returns the updated hand together with the PRNG key left after both draws.
    """
    for slot in range(2):
        card, key = random_card(key)
        hand = hand.at[slot].set(card)
    return hand, key
def draw_card(key, hand, index):
    """Draw one card and store it at position ``index`` of ``hand``.

    Returns ``(key, hand, index + 1)``: the key produced by the draw, the
    updated hand, and the position the next card should be written to.
    """
    card, next_key = random_card(key)
    updated_hand = hand.at[index].set(card)
    return next_key, updated_hand, index + 1
def usable_ace(hand):
    """Check to see if a hand has a usable ace.

    An ace is usable when the hand contains at least one ace (value 1) and
    counting it as 11 instead of 1 (i.e. adding 10) keeps the total at or
    below 21. Returns a boolean JAX array.
    """
    has_ace = jnp.any(hand == 1)
    # jnp.sum reduces the whole array in one operation instead of iterating
    # over it element by element like the builtin sum does.
    fits = jnp.sum(hand) + 10 <= 21
    return jnp.logical_and(has_ace, fits)
def take(env_state):
    """Handle the player's decision to take (hit) another card.

    ``env_state`` is a ``(state, key)`` pair. Returns a new ``(EnvState, key)``
    pair in which the player holds one more card; the dealer is untouched.
    """
    state, key = env_state
    key, hand_after_hit, cards_after_hit = draw_card(
        key, state.player_hand, state.player_cards
    )
    # done stays 0 here: whether the player went bust is determined later.
    updated = EnvState(
        dealer_hand=state.dealer_hand,
        player_hand=hand_after_hit,
        dealer_cards=state.dealer_cards,
        player_cards=cards_after_hit,
        done=0,
    )
    return updated, key
def dealer_stop(val):
    """Loop condition for the dealer's drawing phase.

    ``val`` is the ``(key, dealer_hand, dealer_cards)`` loop carry. Despite the
    name, the result is True while the dealer must keep drawing (hand total
    below 17) and False once the dealer stops.
    """
    dealer_hand = val[1]
    return sum_hand(dealer_hand) < 17
def draw_card_wrapper(val):
    """Adapt ``draw_card`` to a single tuple argument.

    ``val`` is unpacked into the ``(key, hand, index)`` arguments of
    ``draw_card`` and its result is returned unchanged.
    """
    key, hand, index = val
    return draw_card(key, hand, index)
def notake(env_state):
    """Handle the player's decision to not take a card (stick).

    This ends the active portion of the game and turns control over to the
    dealer, who draws until ``dealer_stop`` reports that drawing is finished.
    ``env_state`` is a ``(state, key)`` pair; a new ``(EnvState, key)`` pair
    with ``done`` set to 1 is returned.
    """
    state, key = env_state
    loop_carry = (key, state.dealer_hand, state.dealer_cards)
    # Keep calling draw_card_wrapper for as long as dealer_stop returns True.
    key, final_dealer_hand, final_dealer_cards = jax.lax.while_loop(
        dealer_stop, draw_card_wrapper, loop_carry
    )
    finished = EnvState(
        dealer_hand=final_dealer_hand,
        player_hand=state.player_hand,
        dealer_cards=final_dealer_cards,
        player_cards=state.player_cards,
        done=1,
    )
    return finished, key
def sum_hand(hand):
    """Return the total points in a hand.

    The total is the sum of the card values plus 10 when the hand has a usable
    ace (an ace that can count as 11 without exceeding 21).
    """
    # jnp.sum reduces the array in a single operation; the builtin sum would
    # iterate over it one element at a time.
    return jnp.sum(hand) + (10 * usable_ace(hand))
def is_bust(hand):
    """Tell whether the hand's total exceeds 21, i.e. the hand is a bust."""
    total = sum_hand(hand)
    return total > 21
def score(hand):
    """Return the hand's total, or 0 when the hand is a bust."""
    total = sum_hand(hand)
    return jnp.where(is_bust(hand), 0, total)
def is_natural(hand):
    """Tell whether the hand is a natural blackjack.

    A natural is a hand with exactly two non-zero entries that includes at
    least one ace (value 1) and at least one card worth 10.
    """
    exactly_two_cards = jnp.count_nonzero(hand) == 2
    holds_ace = jnp.count_nonzero(hand == 1) > 0
    holds_ten = jnp.count_nonzero(hand == 10) > 0
    return exactly_two_cards & holds_ace & holds_ten
2023-12-17 15:03:06 +01:00
@struct.dataclass
class BlackJackParams:
    """Configuration values for the jax Blackjack environment."""

    # Pay 1.5 instead of 1 for a win with a natural blackjack; the reward
    # function only honours this when sutton_and_barto is False.
    natural: bool = False
    # Use the Sutton and Barto rule: a player natural against a dealer without
    # a natural is rewarded with 1.
    sutton_and_barto: bool = True
2023-02-21 05:38:11 -06:00
class BlackjackFunctional (
2023-12-17 15:03:06 +01:00
FuncEnv [ jax . Array , jax . Array , int , float , bool , RenderStateType , BlackJackParams ]
2023-02-21 05:38:11 -06:00
) :
""" Blackjack is a card game where the goal is to beat the dealer by obtaining cards that sum to closer to 21 (without going over 21) than the dealers cards.
### Description
Card Values :
- Face cards ( Jack , Queen , King ) have a point value of 10.
- Aces can either count as 11 ( called a ' usable ace ' ) or 1.
- Numerical cards ( 2 - 9 ) have a value equal to their number .
This game is played with an infinite deck ( or with replacement ) .
The game starts with the dealer having one face up and one face down card ,
while the player has two face up cards .
The player can request additional cards ( hit , action = 1 ) until they decide to stop ( stick , action = 0 )
or exceed 21 ( bust , immediate loss ) .
After the player sticks , the dealer reveals their facedown card , and draws
until their sum is 17 or greater . If the dealer goes bust , the player wins .
If neither the player nor the dealer busts , the outcome ( win , lose , draw ) is
decided by whose sum is closer to 21.
### Action Space
There are two actions : stick ( 0 ) , and hit ( 1 ) .
### Observation Space
The observation consists of a 3 - tuple containing : the player ' s current sum,
the value of the dealer ' s one showing card (1-10 where 1 is ace),
and whether the player holds a usable ace ( 0 or 1 ) .
This environment corresponds to the version of the blackjack problem
described in Example 5.1 in Reinforcement Learning: An Introduction
by Sutton and Barto (http://incompleteideas.net/book/the-book-2nd.html).
### Rewards
- win game : + 1
- lose game : - 1
- draw game : 0
- win game with natural blackjack :
+ 1.5 ( if < a href = " #nat " > natural < / a > is True )
+ 1 ( if < a href = " #nat " > natural < / a > is False )
### Arguments
` ` `
gym . make ( ' Jax-Blackjack-v0 ' , natural = False , sutton_and_barto = False )
` ` `
< a id = " nat " > ` natural = False ` < / a > : Whether to give an additional reward for
starting with a natural blackjack , i . e . starting with an ace and ten ( sum is 21 ) .
< a id = " sutton_and_barto " > ` sutton_and_barto = False ` < / a > : Whether to follow the exact rules outlined in the book by
Sutton and Barto. If `sutton_and_barto` is `True` (the default value in `BlackJackParams`), the keyword argument `natural` will be ignored.
If the player achieves a natural blackjack and the dealer does not , the player
will win ( i . e . get a reward of + 1 ) . The reverse rule does not apply .
If both the player and the dealer get a natural , it will be a draw ( i . e . reward 0 ) .
### Version History
* v0: Initial version release (0.0.0), adapted from original gym blackjack v1
"""
# Two discrete actions: stick (0) and hit (1).
action_space = spaces.Discrete(2)
# Three integer entries: player's current sum, dealer's showing card,
# and the usable-ace flag.
observation_space = spaces.Box(
    low=np.array([1, 1, 0]),
    high=np.array([32, 11, 1]),
    shape=(3,),
    dtype=np.int32,
)
metadata = {"render_modes": ["rgb_array"], "render_fps": 4}
2023-12-17 15:03:06 +01:00
def transition(
    self,
    state: EnvState,
    action: Union[int, jax.Array],
    key: PRNGKey,
    params: BlackJackParams = BlackJackParams,
):
    """Compute the next blackjack state from ``state`` and ``action``.

    A truthy ``action`` takes a card for the player (``take``); a falsy one
    sticks and lets the dealer play out their hand (``notake``). The returned
    ``EnvState`` carries the recomputed ``done`` flag.
    """
    next_hands, _ = jax.lax.cond(action, take, notake, (state, key))

    # Only a bust or sticking ends the round; the player can still request
    # another card while holding 21 points.
    bust_after_hit = is_bust(next_hands.player_hand) * action
    stuck = jnp.logical_not(action) * 1

    return EnvState(
        dealer_hand=next_hands.dealer_hand,
        player_hand=next_hands.player_hand,
        dealer_cards=next_hands.dealer_cards,
        player_cards=next_hands.player_cards,
        done=bust_after_hit + stuck,
    )
2023-12-17 15:03:06 +01:00
def initial(self, rng: PRNGKey, params: BlackJackParams = BlackJackParams):
    """Create the starting state: two cards each for the player and the dealer."""
    # Hands are fixed-size arrays of 21 slots; a slot holding 0 is empty.
    empty_hand = jnp.zeros(21)
    # The player's cards are drawn first, then the dealer's, from the same key chain.
    player_hand, rng = draw_hand(rng, empty_hand)
    dealer_hand, rng = draw_hand(rng, empty_hand)
    return EnvState(
        dealer_hand=dealer_hand,
        player_hand=player_hand,
        dealer_cards=2,
        player_cards=2,
        done=0,
    )
2023-12-17 15:03:06 +01:00
def observation(
    self, state: EnvState, params: BlackJackParams = BlackJackParams
) -> jax.Array:
    """Build the observation for ``state``.

    Returns an int32 array ``[player sum, dealer's first card, usable ace flag]``.
    """
    hand = state.player_hand
    player_total = sum_hand(hand)
    dealer_showing = state.dealer_hand[0]
    ace_flag = usable_ace(hand) * 1.0
    return jnp.array([player_total, dealer_showing, ace_flag], dtype=np.int32)
2023-12-17 15:03:06 +01:00
def terminal(
    self, state: EnvState, params: BlackJackParams = BlackJackParams
) -> jax.Array:
    """Report whether ``state`` is terminal, i.e. its ``done`` flag is above 0."""
    round_over = state.done > 0
    return round_over
def reward(
    self,
    state: EnvState,
    action: ActType,
    next_state: StateType,
    params: BlackJackParams = BlackJackParams,
) -> jax.Array:
    """Compute the reward for the transition that ended in ``next_state``.

    Only ``next_state``, ``action`` and ``params`` are consulted; the ``state``
    argument is not used.
    """
    player_hand = next_state.player_hand
    dealer_hand = next_state.dealer_hand

    # Hitting: -1 if the player busts. Sticking: +1 if the player's score beats
    # the dealer's, 0 on a tie, -1 on a loss.
    bust_penalty = is_bust(player_hand) * -1 * action
    showdown = jnp.logical_not(action) * cmp(score(player_hand), score(dealer_hand))
    payoff = 0.0 + bust_penalty + showdown

    # Natural setting: a win with a natural blackjack pays 1.5 instead of 1.
    if params.natural and not params.sutton_and_barto:
        natural_win = jnp.logical_and(is_natural(player_hand), (payoff == 1))
        payoff = jnp.where(natural_win, 1.5, payoff)

    # Sutton and Barto setting: a player natural against a dealer hand that is
    # not a natural is a win for the player.
    if params.sutton_and_barto:
        player_only_natural = jnp.logical_and(
            is_natural(player_hand), jnp.logical_not(is_natural(dealer_hand))
        )
        payoff = jnp.where(player_only_natural, 1, payoff)

    return payoff
def render_init(
    self, screen_width: int = 600, screen_height: int = 500
) -> RenderStateType:
    """Return an initial render state.

    The render state is ``(screen, rank, suit)``: an off-screen pygame surface
    of the requested size, plus the face-card rank and the suit used to draw
    the dealer's visible card. Rank and suit are picked with a generator seeded
    with 0, so they are the same on every call.

    Raises:
        DependencyNotInstalled: if pygame cannot be imported.
    """
    try:
        import pygame
    except ImportError as e:
        # Chain the original error (as render_close does) and point at the
        # same extra that render_image names.
        raise DependencyNotInstalled(
            "pygame is not installed, run `pip install gymnasium[toy_text]`"
        ) from e

    rng = seeding.np_random(0)[0]
    suits = ["C", "D", "H", "S"]
    # The suit is drawn before the rank; keep this order so the picks stay the same.
    dealer_top_card_suit = rng.choice(suits)
    dealer_top_card_value_str = rng.choice(["J", "Q", "K"])

    pygame.init()
    screen = pygame.Surface((screen_width, screen_height))
    return screen, dealer_top_card_value_str, dealer_top_card_suit
def render_image(
    self,
    state: StateType,
    render_state: RenderStateType,
    params: BlackJackParams = BlackJackParams,
) -> Tuple[RenderStateType, np.ndarray]:
    """Render an image from a state.

    Draws the dealer's visible card next to a face-down card, the player's
    current sum and, when applicable, a "usable ace" label onto the surface
    held in ``render_state``.

    Returns:
        The unchanged ``render_state`` and the frame as an array with the first
        two axes of the surface's pixel array swapped.

    Raises:
        DependencyNotInstalled: if pygame cannot be imported.
    """
    try:
        import pygame
    except ImportError as e:
        raise DependencyNotInstalled(
            "pygame is not installed, run `pip install gymnasium[toy_text]`"
        ) from e

    screen, dealer_top_card_value_str, dealer_top_card_suit = render_state
    player_sum, dealer_card_value, has_usable_ace = self.observation(state)

    # Lay everything out relative to the surface that render_init created,
    # so a custom screen size is honoured (the default is 600x500).
    screen_width, screen_height = screen.get_size()
    card_img_height = screen_height // 3
    card_img_width = int(card_img_height * 142 / 197)
    spacing = screen_height // 20

    bg_color = (7, 99, 36)
    white = (255, 255, 255)

    # Pick the rank shown on the dealer's card: ace, a face card for 10,
    # or the number itself.
    if dealer_card_value == 1:
        display_card_value = "A"
    elif dealer_card_value == 10:
        display_card_value = dealer_top_card_value_str
    else:
        display_card_value = str(math.floor(dealer_card_value))

    screen.fill(bg_color)

    # Images and fonts are loaded from the sibling "toy_text" directory.
    asset_dir = os.path.join(os.path.dirname(__file__), "..", "toy_text")

    def get_image(path):
        return pygame.image.load(os.path.join(asset_dir, path))

    def get_font(path, size):
        return pygame.font.Font(os.path.join(asset_dir, path), size)

    def scale_card_img(card_img):
        return pygame.transform.scale(card_img, (card_img_width, card_img_height))

    small_font = get_font(os.path.join("font", "Minecraft.ttf"), screen_height // 15)
    dealer_text = small_font.render("Dealer: " + str(dealer_card_value), True, white)
    dealer_text_rect = screen.blit(dealer_text, (spacing, spacing))

    # Dealer's visible card, left of centre.
    dealer_card_img = scale_card_img(
        get_image(
            os.path.join(
                "img",
                f"{dealer_top_card_suit}{display_card_value}.png",
            )
        )
    )
    dealer_card_rect = screen.blit(
        dealer_card_img,
        (
            screen_width // 2 - card_img_width - spacing // 2,
            dealer_text_rect.bottom + spacing,
        ),
    )

    # Dealer's face-down card, right of centre.
    hidden_card_img = scale_card_img(get_image(os.path.join("img", "Card.png")))
    screen.blit(
        hidden_card_img,
        (
            screen_width // 2 + spacing // 2,
            dealer_text_rect.bottom + spacing,
        ),
    )

    player_text = small_font.render("Player", True, white)
    player_text_rect = screen.blit(
        player_text, (spacing, dealer_card_rect.bottom + 1.5 * spacing)
    )

    large_font = get_font(os.path.join("font", "Minecraft.ttf"), screen_height // 6)
    player_sum_text = large_font.render(str(player_sum), True, white)
    player_sum_text_rect = screen.blit(
        player_sum_text,
        (
            screen_width // 2 - player_sum_text.get_width() // 2,
            player_text_rect.bottom + spacing,
        ),
    )

    if has_usable_ace:
        usable_ace_text = small_font.render("usable ace", True, white)
        screen.blit(
            usable_ace_text,
            (
                screen_width // 2 - usable_ace_text.get_width() // 2,
                player_sum_text_rect.bottom + spacing // 2,
            ),
        )

    # Copy the surface's pixels and swap the first two axes for the returned frame.
    frame = np.transpose(np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2))
    return render_state, frame
def render_close(self, render_state: RenderStateType) -> None:
    """Shut down pygame's display module and then pygame itself.

    ``render_state`` is accepted but not used.

    Raises:
        DependencyNotInstalled: if pygame cannot be imported.
    """
    try:
        import pygame
    except ImportError as e:
        raise DependencyNotInstalled(
            "pygame is not installed, run `pip install gymnasium[classic_control]`"
        ) from e
    else:
        pygame.display.quit()
        pygame.quit()
2023-12-17 15:03:06 +01:00
def get_default_params(self, **kwargs) -> BlackJackParams:
    """Build a ``BlackJackParams`` instance, forwarding ``kwargs`` as field overrides."""
    params = BlackJackParams(**kwargs)
    return params
2023-02-21 05:38:11 -06:00
class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle):
    """Gymnasium environment that wraps the functional blackjack environment."""

    metadata = dict(render_modes=["rgb_array"], render_fps=50, jax=True)

    def __init__(self, render_mode: Optional[str] = None, **kwargs):
        """Create the functional env, apply ``jax.jit`` to it and wrap it.

        ``kwargs`` are forwarded to ``BlackjackFunctional`` and recorded,
        together with ``render_mode``, for pickling via ``EzPickle``.
        """
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)

        func_env = BlackjackFunctional(**kwargs)
        func_env.transform(jax.jit)

        super().__init__(
            func_env,
            metadata=self.metadata,
            render_mode=render_mode,
        )
# Pixel art from Mariia Khmelnytska (https://www.123rf.com/photo_104453049_stock-vector-pixel-art-playing-cards-standart-deck-vector-set.html)
# Jax structure inspired by https://medium.com/@ngoodger_7766/writing-an-rl-environment-in-jax-9f74338898ba
if __name__ == "__main__":
    # Temporary environment tester: play one interactive round from the
    # terminal, typing the action at the prompt (0 = stick, 1 = hit).
    env = HumanRendering(BlackJackJaxEnv(render_mode="rgb_array"))

    obs, info = env.reset()
    print(obs, info)

    terminal = truncated = False
    # Stop on either end-of-episode signal returned by step().
    while not (terminal or truncated):
        action = int(input("Please input an action\n"))
        obs, reward, terminal, truncated, info = env.step(action)
        print(obs, reward, terminal, truncated, info)

    # Close the environment explicitly instead of calling the site-provided
    # exit() helper; the script ends here anyway.
    env.close()