baselines: default logger similar to configure() logger, rcall: don't call logger.configure() for new rl_algs
* error if logger looks wrong * check version of logger, call logger.configure() on import * remove changes entry * add version to rl-algs * fix typo * add comment * switch version to string * set logger env variable
This commit is contained in:
committed by
Peter Zhokhov
parent
b29c8020d7
commit
7bccb2969f
@@ -344,8 +344,6 @@ class Logger(object):
|
|||||||
if isinstance(fmt, SeqWriter):
|
if isinstance(fmt, SeqWriter):
|
||||||
fmt.writeseq(map(str, args))
|
fmt.writeseq(map(str, args))
|
||||||
|
|
||||||
Logger.DEFAULT = Logger.CURRENT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)])
|
|
||||||
|
|
||||||
def configure(dir=None, format_strs=None):
|
def configure(dir=None, format_strs=None):
|
||||||
if dir is None:
|
if dir is None:
|
||||||
dir = os.getenv('OPENAI_LOGDIR')
|
dir = os.getenv('OPENAI_LOGDIR')
|
||||||
@@ -356,8 +354,12 @@ def configure(dir=None, format_strs=None):
|
|||||||
os.makedirs(dir, exist_ok=True)
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
|
||||||
log_suffix = ''
|
log_suffix = ''
|
||||||
from mpi4py import MPI
|
rank = 0
|
||||||
rank = MPI.COMM_WORLD.Get_rank()
|
# check environment variables here instead of importing mpi4py
|
||||||
|
# to avoid calling MPI_Init() when this module is imported
|
||||||
|
for varname in ['PMI_RANK', 'OMPI_COMM_WORLD_RANK']:
|
||||||
|
if varname in os.environ:
|
||||||
|
rank = int(os.environ[varname])
|
||||||
if rank > 0:
|
if rank > 0:
|
||||||
log_suffix = "-rank%03i" % rank
|
log_suffix = "-rank%03i" % rank
|
||||||
|
|
||||||
@@ -372,6 +374,14 @@ def configure(dir=None, format_strs=None):
|
|||||||
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
|
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
|
||||||
log('Logging to %s'%dir)
|
log('Logging to %s'%dir)
|
||||||
|
|
||||||
|
def _configure_default_logger():
|
||||||
|
format_strs = None
|
||||||
|
# keep the old default of only writing to stdout
|
||||||
|
if 'OPENAI_LOG_FORMAT' not in os.environ:
|
||||||
|
format_strs = ['stdout']
|
||||||
|
configure(format_strs=format_strs)
|
||||||
|
Logger.DEFAULT = Logger.CURRENT
|
||||||
|
|
||||||
def reset():
|
def reset():
|
||||||
if Logger.CURRENT is not Logger.DEFAULT:
|
if Logger.CURRENT is not Logger.DEFAULT:
|
||||||
Logger.CURRENT.close()
|
Logger.CURRENT.close()
|
||||||
@@ -471,5 +481,8 @@ def read_tb(path):
|
|||||||
data[step-1, colidx] = value
|
data[step-1, colidx] = value
|
||||||
return pandas.DataFrame(data, columns=tags)
|
return pandas.DataFrame(data, columns=tags)
|
||||||
|
|
||||||
|
# configure the default logger on import
|
||||||
|
_configure_default_logger()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_demo()
|
_demo()
|
||||||
|
Reference in New Issue
Block a user