2016-04-27 08:00:58 -07:00
import json
import os
import time
from gym import error
2016-05-06 18:19:16 -07:00
from gym . utils import atomic_write
2016-04-27 08:00:58 -07:00
class StatsRecorder(object):
    """Accumulate per-episode statistics and write them to a JSON file.

    For every finished episode the recorder stores its length, its summed
    reward, its type ('t' or 'e') and a completion timestamp. ``flush``
    serialises those lists to ``<directory>/<file_prefix>.stats.json``.

    The ``before_*`` / ``after_*`` methods are hooks meant to be called
    around the corresponding environment ``step`` / ``reset`` calls.
    """

    def __init__(self, directory, file_prefix, autoreset=False, env_id=None):
        """Set up an empty recorder.

        Args:
            directory: Directory in which the stats file is written.
            file_prefix: Prefix of the stats file name.
            autoreset: If true, ``after_step`` starts a new episode by
                itself as soon as an episode reports ``done``.
            env_id: Identifier interpolated into error messages.
        """
        self.autoreset = autoreset
        self.env_id = env_id

        # Wall-clock time of the first reset; None until one happens.
        self.initial_reset_timestamp = None
        self.directory = directory
        self.file_prefix = file_prefix
        self.episode_lengths = []
        self.episode_rewards = []
        self.episode_types = []  # experimental addition
        self._type = 't'
        self.timestamps = []
        # Per-episode counters; None until the first reset completes.
        self.steps = None
        # Steps taken over the recorder's whole lifetime (never reset).
        self.total_steps = 0
        self.rewards = None
        self.done = None
        self.closed = False

        filename = '{}.stats.json'.format(self.file_prefix)
        self.path = os.path.join(self.directory, filename)

    @property
    def type(self):
        """Episode type recorded for the next episode: 't' or 'e'."""
        return self._type

    @type.setter
    def type(self, episode_type):
        """Set the episode type.

        Raises:
            error.Error: If the value is neither 't' nor 'e'.
        """
        if episode_type not in ('t', 'e'):
            # Format the message here; passing the value as a second
            # exception argument would leave the placeholder unfilled.
            raise error.Error('Invalid episode type {}: must be t for training or e for evaluation'.format(episode_type))
        self._type = episode_type

    def before_step(self, action):
        """Check that the recorder is in a state where stepping is valid.

        Raises:
            error.ResetNeeded: If the current episode is already done, or
                if no reset has completed yet.
        """
        assert not self.closed

        if self.done:
            raise error.ResetNeeded("Trying to step environment which is currently done. While the monitor is active for {}, you cannot step beyond the end of an episode. Call 'env.reset()' to start the next episode.".format(self.env_id))
        elif self.steps is None:
            raise error.ResetNeeded("Trying to step an environment before reset. While the monitor is active for {}, you must call 'env.reset()' before taking an initial step.".format(self.env_id))

    def after_step(self, observation, reward, done, info):
        """Account for one completed step.

        Updates the step and reward counters. When ``done`` is true the
        episode is stored via ``save_complete`` and, if ``autoreset`` is
        enabled, a new episode is started immediately.
        """
        self.steps += 1
        self.total_steps += 1
        self.rewards += reward
        self.done = done

        if done:
            self.save_complete()
            if self.autoreset:
                self.before_reset()
                self.after_reset(observation)

    def before_reset(self):
        """Validate and prepare for a reset.

        Records the time of the very first reset.

        Raises:
            error.Error: If the running episode has taken at least one
                step and is not done.
        """
        assert not self.closed

        # ``steps`` is still None if a previous reset never reached
        # ``after_reset``; guard it so the comparison cannot fail.
        episode_in_progress = (
            self.done is not None
            and not self.done
            and self.steps is not None
            and self.steps > 0
        )
        if episode_in_progress:
            raise error.Error("Tried to reset environment which is not done. While the monitor is active for {}, you cannot call reset() unless the episode is over.".format(self.env_id))

        self.done = False
        if self.initial_reset_timestamp is None:
            self.initial_reset_timestamp = time.time()

    def after_reset(self, observation):
        """Start counting a fresh episode."""
        self.steps = 0
        self.rewards = 0
        # We write the type at the beginning of the episode. If a user
        # changes the type, it's more natural for it to apply next
        # time the user calls reset().
        self.episode_types.append(self._type)

    def save_complete(self):
        """Store the current episode's length, reward and end time.

        Does nothing if no episode has been started yet.
        """
        if self.steps is not None:
            self.episode_lengths.append(self.steps)
            self.episode_rewards.append(self.rewards)
            self.timestamps.append(time.time())

    def close(self):
        """Flush the statistics to disk and mark the recorder closed."""
        self.flush()
        self.closed = True

    def flush(self):
        """Write the collected statistics to ``self.path`` as JSON.

        Does nothing once the recorder has been closed.
        """
        if self.closed:
            return

        with atomic_write.atomic_write(self.path) as f:
            json.dump({
                'initial_reset_timestamp': self.initial_reset_timestamp,
                'timestamps': self.timestamps,
                'episode_lengths': self.episode_lengths,
                'episode_rewards': self.episode_rewards,
                'episode_types': self.episode_types,
            }, f)