delayed logger configuration (#208)

* delayed logger configuration

* fix typo

* setters and getters for Logger.DEFAULT as well

* do away with fancy property stuff - unable to get it to work with class level methods

* grammar and spaces

* spaces

* use get_current function instead of reading Logger.CURRENT

* autopep8
This commit is contained in:
pzhokhov
2019-01-07 11:07:19 -08:00
committed by Peter Zhokhov
parent 370ee27750
commit 3a8f35a7e9
4 changed files with 24 additions and 20 deletions

View File

@@ -30,6 +30,7 @@ def make_vec_env(env_id, env_type, num_env, seed,
wrapper_kwargs = wrapper_kwargs or {}
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
seed = seed + 10000 * mpi_rank if seed is not None else None
logger_dir = logger.get_dir()
def make_thunk(rank):
return lambda: make_env(
env_id=env_id,
@@ -39,7 +40,8 @@ def make_vec_env(env_id, env_type, num_env, seed,
reward_scale=reward_scale,
gamestate=gamestate,
flatten_dict_observations=flatten_dict_observations,
wrapper_kwargs=wrapper_kwargs
wrapper_kwargs=wrapper_kwargs,
logger_dir=logger_dir
)
set_global_seeds(seed)
@@ -49,7 +51,7 @@ def make_vec_env(env_id, env_type, num_env, seed,
return DummyVecEnv([make_thunk(start_index)])
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None):
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
wrapper_kwargs = wrapper_kwargs or {}
if env_type == 'atari':
@@ -67,7 +69,7 @@ def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate
env.seed(seed + subrank if seed is not None else None)
env = Monitor(env,
logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
allow_early_resets=True)
if env_type == 'atari':

View File

@@ -126,4 +126,4 @@ def mpi_weighted_mean(comm, local_name2valcount):
name2count[name] += count
return {name : name2sum[name] / name2count[name] for name in name2sum}
else:
return {}
return {}

View File

@@ -196,13 +196,13 @@ def logkv(key, val):
Call this once for each diagnostic quantity, each iteration
If called many times, last value will be used.
"""
Logger.CURRENT.logkv(key, val)
get_current().logkv(key, val)
def logkv_mean(key, val):
"""
The same as logkv(), but if called many times, values averaged.
"""
Logger.CURRENT.logkv_mean(key, val)
get_current().logkv_mean(key, val)
def logkvs(d):
"""
@@ -221,17 +221,17 @@ def dumpkvs(mpi_mean=False):
the root worker collect all of the stats and write the average,
and no one else writes anything.
"""
return Logger.CURRENT.dumpkvs(mpi_mean=mpi_mean)
return get_current().dumpkvs(mpi_mean=mpi_mean)
def getkvs():
return Logger.CURRENT.name2val
return get_current().name2val
def log(*args, level=INFO):
"""
Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
"""
Logger.CURRENT.log(*args, level=level)
get_current().log(*args, level=level)
def debug(*args):
log(*args, level=DEBUG)
@@ -250,14 +250,14 @@ def set_level(level):
"""
Set logging threshold on current logger.
"""
Logger.CURRENT.set_level(level)
get_current().set_level(level)
def get_dir():
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call start)
"""
return Logger.CURRENT.get_dir()
return get_current().get_dir()
record_tabular = logkv
dump_tabular = dumpkvs
@@ -273,7 +273,7 @@ class ProfileKV:
def __enter__(self):
self.t1 = time.time()
def __exit__(self ,type, value, traceback):
Logger.CURRENT.name2val[self.n] += time.time() - self.t1
get_current().name2val[self.n] += time.time() - self.t1
def profile(n):
"""
@@ -293,6 +293,13 @@ def profile(n):
# Backend
# ================================================================
def get_current():
if Logger.CURRENT is None:
_configure_default_logger()
return Logger.CURRENT
class Logger(object):
DEFAULT = None # A logger with no output files. (See right below class definition)
# So that you can still log to the terminal without setting up any output files
@@ -409,7 +416,7 @@ class scoped_configure(object):
self.format_strs = format_strs
self.prevlogger = None
def __enter__(self):
self.prevlogger = Logger.CURRENT
self.prevlogger = get_current()
configure(dir=self.dir, format_strs=self.format_strs)
def __exit__(self, *args):
Logger.CURRENT.close()
@@ -437,7 +444,7 @@ def _demo():
logkv_mean("b", -44.4)
logkv("a", 5.5)
dumpkvs()
info("^^^ should see b = 33.3")
info("^^^ should see b = -33.3")
logkv("b", -2.5)
dumpkvs()
@@ -495,10 +502,5 @@ def read_tb(path):
data[step-1, colidx] = value
return pandas.DataFrame(data, columns=tags)
# configure the default logger on import
_configure_default_logger()
if __name__ == "__main__":
_demo()

View File

@@ -64,7 +64,7 @@ def train(args, extra_args):
env = build_env(args)
if args.save_video_interval != 0:
env = VecVideoRecorder(env, osp.join(logger.Logger.CURRENT.dir, "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)
env = VecVideoRecorder(env, osp.join(logger.get_dir(), "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)
if args.network:
alg_kwargs['network'] = args.network