From 7bccb2969fa860e5a62a9ee26d3bb57d736e32d9 Mon Sep 17 00:00:00 2001 From: Christopher Hesse Date: Thu, 30 Aug 2018 15:04:40 -0700 Subject: [PATCH] baselines: default logger similar to configure() logger, rcall: don't call logger.configure() for new rl_algs * error if logger looks wrong * check version of logger, call logger.configure() on import * remove changes entry * add version to rl-algs * fix typo * add comment * switch version to string * set logger env variable --- baselines/logger.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/baselines/logger.py b/baselines/logger.py index 0abad0e..be38f43 100644 --- a/baselines/logger.py +++ b/baselines/logger.py @@ -344,8 +344,6 @@ class Logger(object): if isinstance(fmt, SeqWriter): fmt.writeseq(map(str, args)) -Logger.DEFAULT = Logger.CURRENT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)]) - def configure(dir=None, format_strs=None): if dir is None: dir = os.getenv('OPENAI_LOGDIR') @@ -356,8 +354,12 @@ def configure(dir=None, format_strs=None): os.makedirs(dir, exist_ok=True) log_suffix = '' - from mpi4py import MPI - rank = MPI.COMM_WORLD.Get_rank() + rank = 0 + # check environment variables here instead of importing mpi4py + # to avoid calling MPI_Init() when this module is imported + for varname in ['PMI_RANK', 'OMPI_COMM_WORLD_RANK']: + if varname in os.environ: + rank = int(os.environ[varname]) if rank > 0: log_suffix = "-rank%03i" % rank @@ -372,6 +374,14 @@ def configure(dir=None, format_strs=None): Logger.CURRENT = Logger(dir=dir, output_formats=output_formats) log('Logging to %s'%dir) +def _configure_default_logger(): + format_strs = None + # keep the old default of only writing to stdout + if 'OPENAI_LOG_FORMAT' not in os.environ: + format_strs = ['stdout'] + configure(format_strs=format_strs) + Logger.DEFAULT = Logger.CURRENT + def reset(): if Logger.CURRENT is not Logger.DEFAULT: Logger.CURRENT.close() @@ -471,5 +481,8 @@ def read_tb(path): data[step-1, colidx] = value return pandas.DataFrame(data, columns=tags) +# configure the default logger on import +_configure_default_logger() + if __name__ == "__main__": _demo()