Everyrl initial commit & a few minor baselines changes (#226)
* everyrl initial commit
* add keep_buf argument to VecMonitor
* logger changes: set_comm and fix to mpi_mean functionality
* if filename not provided, don't create ResultsWriter
* change variable syncing function to simplify its usage. now you should initialize from all mpi processes
* everyrl coinrun changes
* tf_distr changes, bugfix
* get_one
* bring back get_next to temporarily restore code
* lint fixes
* fix test
* rename profile function
* rename gaussian
* fix coinrun training script
This commit is contained in:
committed by
Peter Zhokhov
parent
cd8d3389ba
commit
82ebd4a153
@@ -16,11 +16,13 @@ class Monitor(Wrapper):
|
||||
def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
|
||||
Wrapper.__init__(self, env=env)
|
||||
self.tstart = time.time()
|
||||
self.results_writer = ResultsWriter(
|
||||
filename,
|
||||
header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
|
||||
extra_keys=reset_keywords + info_keywords
|
||||
)
|
||||
if filename:
|
||||
self.results_writer = ResultsWriter(filename,
|
||||
header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
|
||||
extra_keys=reset_keywords + info_keywords
|
||||
)
|
||||
else:
|
||||
self.results_writer = None
|
||||
self.reset_keywords = reset_keywords
|
||||
self.info_keywords = info_keywords
|
||||
self.allow_early_resets = allow_early_resets
|
||||
@@ -68,8 +70,9 @@ class Monitor(Wrapper):
|
||||
self.episode_lengths.append(eplen)
|
||||
self.episode_times.append(time.time() - self.tstart)
|
||||
epinfo.update(self.current_reset_info)
|
||||
self.results_writer.write_row(epinfo)
|
||||
|
||||
if self.results_writer:
|
||||
self.results_writer.write_row(epinfo)
|
||||
assert isinstance(info, dict)
|
||||
if isinstance(info, dict):
|
||||
info['episode'] = epinfo
|
||||
|
||||
@@ -96,24 +99,21 @@ class LoadMonitorResultsError(Exception):
|
||||
|
||||
|
||||
class ResultsWriter(object):
|
||||
def __init__(self, filename=None, header='', extra_keys=()):
|
||||
def __init__(self, filename, header='', extra_keys=()):
|
||||
self.extra_keys = extra_keys
|
||||
if filename is None:
|
||||
self.f = None
|
||||
self.logger = None
|
||||
else:
|
||||
if not filename.endswith(Monitor.EXT):
|
||||
if osp.isdir(filename):
|
||||
filename = osp.join(filename, Monitor.EXT)
|
||||
else:
|
||||
filename = filename + "." + Monitor.EXT
|
||||
self.f = open(filename, "wt")
|
||||
if isinstance(header, dict):
|
||||
header = '# {} \n'.format(json.dumps(header))
|
||||
self.f.write(header)
|
||||
self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
|
||||
self.logger.writeheader()
|
||||
self.f.flush()
|
||||
assert filename is not None
|
||||
if not filename.endswith(Monitor.EXT):
|
||||
if osp.isdir(filename):
|
||||
filename = osp.join(filename, Monitor.EXT)
|
||||
else:
|
||||
filename = filename + "." + Monitor.EXT
|
||||
self.f = open(filename, "wt")
|
||||
if isinstance(header, dict):
|
||||
header = '# {} \n'.format(json.dumps(header))
|
||||
self.f.write(header)
|
||||
self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
|
||||
self.logger.writeheader()
|
||||
self.f.flush()
|
||||
|
||||
def write_row(self, epinfo):
|
||||
if self.logger:
|
||||
@@ -121,7 +121,6 @@ class ResultsWriter(object):
|
||||
self.f.flush()
|
||||
|
||||
|
||||
|
||||
def get_monitor_files(dir):
|
||||
return glob(osp.join(dir, "*" + Monitor.EXT))
|
||||
|
||||
|
@@ -19,15 +19,10 @@ def sync_from_root(sess, variables, comm=None):
|
||||
variables: all parameter variables including optimizer's
|
||||
"""
|
||||
if comm is None: comm = MPI.COMM_WORLD
|
||||
rank = comm.Get_rank()
|
||||
for var in variables:
|
||||
if rank == 0:
|
||||
comm.Bcast(sess.run(var))
|
||||
else:
|
||||
import tensorflow as tf
|
||||
returned_var = np.empty(var.shape, dtype='float32')
|
||||
comm.Bcast(returned_var)
|
||||
sess.run(tf.assign(var, returned_var))
|
||||
import tensorflow as tf
|
||||
values = comm.bcast(sess.run(variables))
|
||||
sess.run([tf.assign(var, val)
|
||||
for (var, val) in zip(variables, values)])
|
||||
|
||||
def gpu_count():
|
||||
"""
|
||||
|
@@ -6,21 +6,22 @@ from baselines.common import mpi_util
|
||||
def test_mpi_weighted_mean():
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
if comm.rank == 0:
|
||||
name2valcount = {'a' : (10, 2), 'b' : (20,3)}
|
||||
elif comm.rank == 1:
|
||||
name2valcount = {'a' : (19, 1), 'c' : (42,3)}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
with logger.scoped_configure(comm=comm):
|
||||
if comm.rank == 0:
|
||||
name2valcount = {'a' : (10, 2), 'b' : (20,3)}
|
||||
elif comm.rank == 1:
|
||||
name2valcount = {'a' : (19, 1), 'c' : (42,3)}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
d = mpi_util.mpi_weighted_mean(comm, name2valcount)
|
||||
correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
|
||||
if comm.rank == 0:
|
||||
assert d == correctval, f'{d} != {correctval}'
|
||||
d = mpi_util.mpi_weighted_mean(comm, name2valcount)
|
||||
correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
|
||||
if comm.rank == 0:
|
||||
assert d == correctval, f'{d} != {correctval}'
|
||||
|
||||
for name, (val, count) in name2valcount.items():
|
||||
for _ in range(count):
|
||||
logger.logkv_mean(name, val)
|
||||
d2 = logger.dumpkvs(mpi_mean=True)
|
||||
if comm.rank == 0:
|
||||
assert d2 == correctval
|
||||
for name, (val, count) in name2valcount.items():
|
||||
for _ in range(count):
|
||||
logger.logkv_mean(name, val)
|
||||
d2 = logger.dumpkvs()
|
||||
if comm.rank == 0:
|
||||
assert d2 == correctval
|
||||
|
@@ -2,15 +2,22 @@ from . import VecEnvWrapper
|
||||
from baselines.bench.monitor import ResultsWriter
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
from collections import deque
|
||||
|
||||
class VecMonitor(VecEnvWrapper):
|
||||
def __init__(self, venv, filename=None):
|
||||
def __init__(self, venv, filename=None, keep_buf=0):
|
||||
VecEnvWrapper.__init__(self, venv)
|
||||
self.eprets = None
|
||||
self.eplens = None
|
||||
self.tstart = time.time()
|
||||
self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
|
||||
if filename:
|
||||
self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
|
||||
else:
|
||||
self.results_writer = None
|
||||
self.keep_buf = keep_buf
|
||||
if self.keep_buf:
|
||||
self.epret_buf = deque([], maxlen=keep_buf)
|
||||
self.eplen_buf = deque([], maxlen=keep_buf)
|
||||
|
||||
def reset(self):
|
||||
obs = self.venv.reset()
|
||||
@@ -28,10 +35,13 @@ class VecMonitor(VecEnvWrapper):
|
||||
if done:
|
||||
epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
|
||||
info['episode'] = epinfo
|
||||
if self.keep_buf:
|
||||
self.epret_buf.append(ret)
|
||||
self.eplen_buf.append(eplen)
|
||||
self.eprets[i] = 0
|
||||
self.eplens[i] = 0
|
||||
self.results_writer.write_row(epinfo)
|
||||
|
||||
if self.results_writer:
|
||||
self.results_writer.write_row(epinfo)
|
||||
newinfos.append(info)
|
||||
|
||||
return obs, rews, dones, newinfos
|
||||
|
@@ -7,6 +7,7 @@ import time
|
||||
import datetime
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
@@ -211,17 +212,11 @@ def logkvs(d):
|
||||
for (k, v) in d.items():
|
||||
logkv(k, v)
|
||||
|
||||
def dumpkvs(mpi_mean=False):
|
||||
def dumpkvs():
|
||||
"""
|
||||
Write all of the diagnostics from the current iteration
|
||||
|
||||
mpi_mean: whether to average across MPI workers. mpi_mean=False just
|
||||
has each worker write its own stats (and under default settings
|
||||
non-root workers don't write anything), whereas mpi_mean=True has
|
||||
the root worker collect all of the stats and write the average,
|
||||
and no one else writes anything.
|
||||
"""
|
||||
return get_current().dumpkvs(mpi_mean=mpi_mean)
|
||||
return get_current().dumpkvs()
|
||||
|
||||
def getkvs():
|
||||
return get_current().name2val
|
||||
@@ -252,6 +247,9 @@ def set_level(level):
|
||||
"""
|
||||
get_current().set_level(level)
|
||||
|
||||
def set_comm(comm):
|
||||
get_current().set_comm(comm)
|
||||
|
||||
def get_dir():
|
||||
"""
|
||||
Get directory that log files are being written to.
|
||||
@@ -262,18 +260,14 @@ def get_dir():
|
||||
record_tabular = logkv
|
||||
dump_tabular = dumpkvs
|
||||
|
||||
class ProfileKV:
|
||||
"""
|
||||
Usage:
|
||||
with logger.ProfileKV("interesting_scope"):
|
||||
code
|
||||
"""
|
||||
def __init__(self, n):
|
||||
self.n = "wait_" + n
|
||||
def __enter__(self):
|
||||
self.t1 = time.time()
|
||||
def __exit__(self ,type, value, traceback):
|
||||
get_current().name2val[self.n] += time.time() - self.t1
|
||||
@contextmanager
|
||||
def profile_kv(scopename):
|
||||
logkey = 'wait_' + scopename
|
||||
tstart = time.time()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
get_current().name2val[logkey] += time.time() - tstart
|
||||
|
||||
def profile(n):
|
||||
"""
|
||||
@@ -283,7 +277,7 @@ def profile(n):
|
||||
"""
|
||||
def decorator_with_name(func):
|
||||
def func_wrapper(*args, **kwargs):
|
||||
with ProfileKV(n):
|
||||
with profile_kv(n):
|
||||
return func(*args, **kwargs)
|
||||
return func_wrapper
|
||||
return decorator_with_name
|
||||
@@ -305,12 +299,13 @@ class Logger(object):
|
||||
# So that you can still log to the terminal without setting up any output files
|
||||
CURRENT = None # Current logger being used by the free functions above
|
||||
|
||||
def __init__(self, dir, output_formats):
|
||||
def __init__(self, dir, output_formats, comm=None):
|
||||
self.name2val = defaultdict(float) # values this iteration
|
||||
self.name2cnt = defaultdict(int)
|
||||
self.level = INFO
|
||||
self.dir = dir
|
||||
self.output_formats = output_formats
|
||||
self.comm = comm
|
||||
|
||||
# Logging API, forwarded
|
||||
# ----------------------------------------
|
||||
@@ -322,19 +317,16 @@ class Logger(object):
|
||||
self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
|
||||
self.name2cnt[key] = cnt + 1
|
||||
|
||||
def dumpkvs(self, mpi_mean=False):
|
||||
if self.level == DISABLED: return
|
||||
if mpi_mean:
|
||||
def dumpkvs(self):
|
||||
if self.comm is None:
|
||||
d = self.name2val
|
||||
else:
|
||||
from baselines.common import mpi_util
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
d = mpi_util.mpi_weighted_mean(comm,
|
||||
d = mpi_util.mpi_weighted_mean(self.comm,
|
||||
{name : (val, self.name2cnt.get(name, 1))
|
||||
for (name, val) in self.name2val.items()})
|
||||
if comm.rank != 0:
|
||||
if self.comm.rank != 0:
|
||||
d['dummy'] = 1 # so we don't get a warning about empty dict
|
||||
else:
|
||||
d = self.name2val
|
||||
out = d.copy() # Return the dict for unit testing purposes
|
||||
for fmt in self.output_formats:
|
||||
if isinstance(fmt, KVWriter):
|
||||
@@ -352,6 +344,9 @@ class Logger(object):
|
||||
def set_level(self, level):
|
||||
self.level = level
|
||||
|
||||
def set_comm(self, comm):
|
||||
self.comm = comm
|
||||
|
||||
def get_dir(self):
|
||||
return self.dir
|
||||
|
||||
@@ -366,7 +361,10 @@ class Logger(object):
|
||||
if isinstance(fmt, SeqWriter):
|
||||
fmt.writeseq(map(str, args))
|
||||
|
||||
def configure(dir=None, format_strs=None):
|
||||
def configure(dir=None, format_strs=None, comm=None):
|
||||
"""
|
||||
If comm is provided, average all numerical stats across that comm
|
||||
"""
|
||||
if dir is None:
|
||||
dir = os.getenv('OPENAI_LOGDIR')
|
||||
if dir is None:
|
||||
@@ -393,7 +391,7 @@ def configure(dir=None, format_strs=None):
|
||||
format_strs = filter(None, format_strs)
|
||||
output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
|
||||
|
||||
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
|
||||
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
|
||||
log('Logging to %s'%dir)
|
||||
|
||||
def _configure_default_logger():
|
||||
@@ -406,17 +404,15 @@ def reset():
|
||||
Logger.CURRENT = Logger.DEFAULT
|
||||
log('Reset logger')
|
||||
|
||||
class scoped_configure(object):
|
||||
def __init__(self, dir=None, format_strs=None):
|
||||
self.dir = dir
|
||||
self.format_strs = format_strs
|
||||
self.prevlogger = None
|
||||
def __enter__(self):
|
||||
self.prevlogger = get_current()
|
||||
configure(dir=self.dir, format_strs=self.format_strs)
|
||||
def __exit__(self, *args):
|
||||
@contextmanager
|
||||
def scoped_configure(dir=None, format_strs=None, comm=None):
|
||||
prevlogger = Logger.CURRENT
|
||||
configure(dir=dir, format_strs=format_strs, comm=comm)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
Logger.CURRENT.close()
|
||||
Logger.CURRENT = self.prevlogger
|
||||
Logger.CURRENT = prevlogger
|
||||
|
||||
# ================================================================
|
||||
|
||||
|
Reference in New Issue
Block a user