"""Wrapper that tracks the cumulative rewards and episode lengths."""

from __future__ import annotations

import time
from collections import deque

import numpy as np

from gymnasium.core import ActType, ObsType
from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper


__all__ = ["RecordEpisodeStatistics"]


class RecordEpisodeStatistics(VectorWrapper):
    """This wrapper will keep track of cumulative rewards and episode lengths.

    At the end of any episode within the vectorized env, the statistics of the episode
    will be added to ``info`` using the key ``episode``, and the ``_episode`` key
    is used to indicate the environment index which has a terminated or truncated episode.

        >>> infos = {  # doctest: +SKIP
        ...     ...
        ...     "episode": {
        ...         "r": "<array of cumulative reward for each done sub-environment>",
        ...         "l": "<array of episode length for each done sub-environment>",
        ...         "t": "<array of elapsed time since beginning of episode for each done sub-environment>"
        ...     },
        ...     "_episode": "<boolean array of length num-envs>"
        ... }

    Moreover, the most recent rewards, episode lengths and elapsed times are stored in buffers
    that can be accessed via :attr:`wrapped_env.return_queue`, :attr:`wrapped_env.length_queue`
    and :attr:`wrapped_env.time_queue` respectively.

    Attributes:
        return_queue: The cumulative rewards of the last ``buffer_length``-many episodes
        length_queue: The lengths of the last ``buffer_length``-many episodes
        time_queue: The elapsed times of the last ``buffer_length``-many episodes

    Example:
        >>> from pprint import pprint
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3)
        >>> envs = RecordEpisodeStatistics(envs)
        >>> obs, info = envs.reset(123)
        >>> _ = envs.action_space.seed(123)
        >>> end = False
        >>> while not end:
        ...     obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
        ...     end = term.any() or trunc.any()
        ...
        >>> envs.close()
        >>> pprint(info)  # doctest: +SKIP
        {'_episode': array([ True, False, False]),
         '_final_info': array([ True, False, False]),
         '_final_observation': array([ True, False, False]),
         'episode': {'l': array([11,  0,  0], dtype=int32),
                     'r': array([11.,  0.,  0.], dtype=float32),
                     't': array([0.007812, 0.      , 0.      ], dtype=float32)},
         'final_info': array([{}, None, None], dtype=object),
         'final_observation': array([array([ 0.11448676,  0.9416149 , -0.20946532, -1.7619033 ], dtype=float32),
                                     None, None], dtype=object)}
    """

    def __init__(
        self,
        env: VectorEnv,
        buffer_length: int = 100,
        stats_key: str = "episode",
    ):
        """This wrapper will keep track of cumulative rewards and episode lengths.

        Args:
            env (Env): The environment to apply the wrapper
            buffer_length: The size of the buffers :attr:`return_queue`, :attr:`length_queue` and :attr:`time_queue`
            stats_key: The info key to save the data
        """
        super().__init__(env)
        self._stats_key = stats_key

        self.episode_count = 0

        # Scalar placeholders; replaced by per-sub-environment arrays on the first `reset`.
        self.episode_start_times: np.ndarray = np.zeros(())
        self.episode_returns: np.ndarray = np.zeros(())
        self.episode_lengths: np.ndarray = np.zeros((), dtype=int)
        self.prev_dones: np.ndarray = np.zeros((), dtype=bool)

        self.time_queue = deque(maxlen=buffer_length)
        self.return_queue = deque(maxlen=buffer_length)
        self.length_queue = deque(maxlen=buffer_length)

    def reset(
        self,
        seed: int | list[int] | None = None,
        options: dict | None = None,
    ):
        """Resets the environment using kwargs and resets the episode returns and lengths."""
        obs, info = super().reset(seed=seed, options=options)

        self.episode_start_times = np.full(self.num_envs, time.perf_counter())
        self.episode_returns = np.zeros(self.num_envs)
        self.episode_lengths = np.zeros(self.num_envs, dtype=int)
        self.prev_dones = np.zeros(self.num_envs, dtype=bool)

        return obs, info

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Steps through the environment, recording the episode statistics."""
        (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        ) = self.env.step(actions)

        assert isinstance(
            infos, dict
        ), f"`vector.RecordEpisodeStatistics` requires `info` type to be `dict`, its actual type is {type(infos)}. This may be due to usage of other wrappers in the wrong order."

        # Sub-environments that finished on the previous step have been autoreset,
        # so their statistics start over; all others keep accumulating.
        self.episode_returns[self.prev_dones] = 0
        self.episode_lengths[self.prev_dones] = 0
        self.episode_start_times[self.prev_dones] = time.perf_counter()
        self.episode_returns[~self.prev_dones] += rewards[~self.prev_dones]
        self.episode_lengths[~self.prev_dones] += 1

        self.prev_dones = dones = np.logical_or(terminations, truncations)
        num_dones = np.sum(dones)

        if num_dones:
            if self._stats_key in infos or f"_{self._stats_key}" in infos:
                raise ValueError(
                    f"Attempted to add episode stats when they already exist, info keys: {list(infos.keys())}"
                )
            else:
                episode_time_length = np.round(
                    time.perf_counter() - self.episode_start_times, 6
                )
                # Report statistics only for the done sub-environments; the rest are zeroed out.
                infos[self._stats_key] = {
                    "r": np.where(dones, self.episode_returns, 0.0),
                    "l": np.where(dones, self.episode_lengths, 0),
                    "t": np.where(dones, episode_time_length, 0.0),
                }
                infos[f"_{self._stats_key}"] = dones

            self.episode_count += num_dones

            # `np.where(dones)` is a 1-tuple of index arrays, so this loop runs once
            # with `i` holding the indices of every done sub-environment, extending
            # each queue with all finished episodes at once.
            for i in np.where(dones):
                self.time_queue.extend(episode_time_length[i])
                self.return_queue.extend(self.episode_returns[i])
                self.length_queue.extend(self.episode_lengths[i])

        return (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        )
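

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the upstream
    # wrapper): shows how the statistics buffers recorded above might be
    # consumed to report running averages over recent episodes. Assumes
    # `CartPole-v1` is registered, as in the class docstring's example; the
    # step count of 500 is arbitrary.
    import gymnasium as gym

    envs = RecordEpisodeStatistics(gym.make_vec("CartPole-v1", num_envs=3))
    envs.reset(seed=123)
    envs.action_space.seed(123)

    # Vector envs autoreset done sub-environments, so episodes accumulate
    # in the queues as we keep stepping.
    for _ in range(500):
        envs.step(envs.action_space.sample())
    envs.close()

    if len(envs.return_queue) > 0:
        print(f"episodes completed: {envs.episode_count}")
        print(f"mean return over last {len(envs.return_queue)} episodes: {np.mean(envs.return_queue):.2f}")
        print(f"mean length over last {len(envs.length_queue)} episodes: {np.mean(envs.length_queue):.1f}")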