delayed logger configuration (#208)
* delayed logger configuration * fix typo * setters and getters for Logger.DEFAULT as well * do away with fancy property stuff - unable to get it to work with class level methods * grammar and spaces * spaces * use get_current function instead of reading Logger.CURRENT * autopep8
This commit is contained in:
@@ -30,6 +30,7 @@ def make_vec_env(env_id, env_type, num_env, seed,
|
||||
wrapper_kwargs = wrapper_kwargs or {}
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
seed = seed + 10000 * mpi_rank if seed is not None else None
|
||||
logger_dir = logger.get_dir()
|
||||
def make_thunk(rank):
|
||||
return lambda: make_env(
|
||||
env_id=env_id,
|
||||
@@ -39,7 +40,8 @@ def make_vec_env(env_id, env_type, num_env, seed,
|
||||
reward_scale=reward_scale,
|
||||
gamestate=gamestate,
|
||||
flatten_dict_observations=flatten_dict_observations,
|
||||
wrapper_kwargs=wrapper_kwargs
|
||||
wrapper_kwargs=wrapper_kwargs,
|
||||
logger_dir=logger_dir
|
||||
)
|
||||
|
||||
set_global_seeds(seed)
|
||||
@@ -49,7 +51,7 @@ def make_vec_env(env_id, env_type, num_env, seed,
|
||||
return DummyVecEnv([make_thunk(start_index)])
|
||||
|
||||
|
||||
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None):
|
||||
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
wrapper_kwargs = wrapper_kwargs or {}
|
||||
if env_type == 'atari':
|
||||
@@ -67,7 +69,7 @@ def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate
|
||||
|
||||
env.seed(seed + subrank if seed is not None else None)
|
||||
env = Monitor(env,
|
||||
logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
|
||||
logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
|
||||
allow_early_resets=True)
|
||||
|
||||
if env_type == 'atari':
|
||||
|
@@ -126,4 +126,4 @@ def mpi_weighted_mean(comm, local_name2valcount):
|
||||
name2count[name] += count
|
||||
return {name : name2sum[name] / name2count[name] for name in name2sum}
|
||||
else:
|
||||
return {}
|
||||
return {}
|
||||
|
@@ -196,13 +196,13 @@ def logkv(key, val):
|
||||
Call this once for each diagnostic quantity, each iteration
|
||||
If called many times, last value will be used.
|
||||
"""
|
||||
Logger.CURRENT.logkv(key, val)
|
||||
get_current().logkv(key, val)
|
||||
|
||||
def logkv_mean(key, val):
|
||||
"""
|
||||
The same as logkv(), but if called many times, values averaged.
|
||||
"""
|
||||
Logger.CURRENT.logkv_mean(key, val)
|
||||
get_current().logkv_mean(key, val)
|
||||
|
||||
def logkvs(d):
|
||||
"""
|
||||
@@ -221,17 +221,17 @@ def dumpkvs(mpi_mean=False):
|
||||
the root worker collect all of the stats and write the average,
|
||||
and no one else writes anything.
|
||||
"""
|
||||
return Logger.CURRENT.dumpkvs(mpi_mean=mpi_mean)
|
||||
return get_current().dumpkvs(mpi_mean=mpi_mean)
|
||||
|
||||
def getkvs():
|
||||
return Logger.CURRENT.name2val
|
||||
return get_current().name2val
|
||||
|
||||
|
||||
def log(*args, level=INFO):
|
||||
"""
|
||||
Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
|
||||
"""
|
||||
Logger.CURRENT.log(*args, level=level)
|
||||
get_current().log(*args, level=level)
|
||||
|
||||
def debug(*args):
|
||||
log(*args, level=DEBUG)
|
||||
@@ -250,14 +250,14 @@ def set_level(level):
|
||||
"""
|
||||
Set logging threshold on current logger.
|
||||
"""
|
||||
Logger.CURRENT.set_level(level)
|
||||
get_current().set_level(level)
|
||||
|
||||
def get_dir():
|
||||
"""
|
||||
Get directory that log files are being written to.
|
||||
will be None if there is no output directory (i.e., if you didn't call start)
|
||||
"""
|
||||
return Logger.CURRENT.get_dir()
|
||||
return get_current().get_dir()
|
||||
|
||||
record_tabular = logkv
|
||||
dump_tabular = dumpkvs
|
||||
@@ -273,7 +273,7 @@ class ProfileKV:
|
||||
def __enter__(self):
|
||||
self.t1 = time.time()
|
||||
def __exit__(self ,type, value, traceback):
|
||||
Logger.CURRENT.name2val[self.n] += time.time() - self.t1
|
||||
get_current().name2val[self.n] += time.time() - self.t1
|
||||
|
||||
def profile(n):
|
||||
"""
|
||||
@@ -293,6 +293,13 @@ def profile(n):
|
||||
# Backend
|
||||
# ================================================================
|
||||
|
||||
def get_current():
|
||||
if Logger.CURRENT is None:
|
||||
_configure_default_logger()
|
||||
|
||||
return Logger.CURRENT
|
||||
|
||||
|
||||
class Logger(object):
|
||||
DEFAULT = None # A logger with no output files. (See right below class definition)
|
||||
# So that you can still log to the terminal without setting up any output files
|
||||
@@ -409,7 +416,7 @@ class scoped_configure(object):
|
||||
self.format_strs = format_strs
|
||||
self.prevlogger = None
|
||||
def __enter__(self):
|
||||
self.prevlogger = Logger.CURRENT
|
||||
self.prevlogger = get_current()
|
||||
configure(dir=self.dir, format_strs=self.format_strs)
|
||||
def __exit__(self, *args):
|
||||
Logger.CURRENT.close()
|
||||
@@ -437,7 +444,7 @@ def _demo():
|
||||
logkv_mean("b", -44.4)
|
||||
logkv("a", 5.5)
|
||||
dumpkvs()
|
||||
info("^^^ should see b = 33.3")
|
||||
info("^^^ should see b = -33.3")
|
||||
|
||||
logkv("b", -2.5)
|
||||
dumpkvs()
|
||||
@@ -495,10 +502,5 @@ def read_tb(path):
|
||||
data[step-1, colidx] = value
|
||||
return pandas.DataFrame(data, columns=tags)
|
||||
|
||||
# configure the default logger on import
|
||||
_configure_default_logger()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_demo()
|
||||
|
@@ -64,7 +64,7 @@ def train(args, extra_args):
|
||||
|
||||
env = build_env(args)
|
||||
if args.save_video_interval != 0:
|
||||
env = VecVideoRecorder(env, osp.join(logger.Logger.CURRENT.dir, "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)
|
||||
env = VecVideoRecorder(env, osp.join(logger.get_dir(), "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)
|
||||
|
||||
if args.network:
|
||||
alg_kwargs['network'] = args.network
|
||||
|
Reference in New Issue
Block a user