Everyrl initial commit & a few minor baselines changes (#226)

* everyrl initial commit

* add keep_buf argument to VecMonitor

* logger changes: set_comm and fix to mpi_mean functionality

* if filename not provided, don't create ResultsWriter

* change variable syncing function to simplify its usage. now you should initialize from all mpi processes

* everyrl coinrun changes

* tf_distr changes, bugfix

* get_one

* bring back get_next to temporarily restore code

* lint fixes

* fix test

* rename profile function

* rename gaussian

* fix coinrun training script
Author:       John Schulman
Date:         2019-01-24 15:43:26 -08:00
Committed by: Peter Zhokhov
Commit:       82ebd4a153
Parent:       cd8d3389ba

5 changed files with 99 additions and 98 deletions

baselines/bench/monitor.py

@@ -16,11 +16,13 @@ class Monitor(Wrapper):
     def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
         Wrapper.__init__(self, env=env)
         self.tstart = time.time()
-        self.results_writer = ResultsWriter(
-            filename,
-            header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
-            extra_keys=reset_keywords + info_keywords
-        )
+        if filename:
+            self.results_writer = ResultsWriter(filename,
+                header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
+                extra_keys=reset_keywords + info_keywords
+            )
+        else:
+            self.results_writer = None
         self.reset_keywords = reset_keywords
         self.info_keywords = info_keywords
         self.allow_early_resets = allow_early_resets
@@ -68,8 +70,9 @@ class Monitor(Wrapper):
             self.episode_lengths.append(eplen)
             self.episode_times.append(time.time() - self.tstart)
             epinfo.update(self.current_reset_info)
-            self.results_writer.write_row(epinfo)
+            if self.results_writer:
+                self.results_writer.write_row(epinfo)
             assert isinstance(info, dict)
             if isinstance(info, dict):
                 info['episode'] = epinfo
@@ -96,24 +99,21 @@ class LoadMonitorResultsError(Exception):
 class ResultsWriter(object):
-    def __init__(self, filename=None, header='', extra_keys=()):
+    def __init__(self, filename, header='', extra_keys=()):
         self.extra_keys = extra_keys
-        if filename is None:
-            self.f = None
-            self.logger = None
-        else:
-            if not filename.endswith(Monitor.EXT):
-                if osp.isdir(filename):
-                    filename = osp.join(filename, Monitor.EXT)
-                else:
-                    filename = filename + "." + Monitor.EXT
-            self.f = open(filename, "wt")
-            if isinstance(header, dict):
-                header = '# {} \n'.format(json.dumps(header))
-            self.f.write(header)
-            self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
-            self.logger.writeheader()
-            self.f.flush()
+        assert filename is not None
+        if not filename.endswith(Monitor.EXT):
+            if osp.isdir(filename):
+                filename = osp.join(filename, Monitor.EXT)
+            else:
+                filename = filename + "." + Monitor.EXT
+        self.f = open(filename, "wt")
+        if isinstance(header, dict):
+            header = '# {} \n'.format(json.dumps(header))
+        self.f.write(header)
+        self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
+        self.logger.writeheader()
+        self.f.flush()

     def write_row(self, epinfo):
         if self.logger:
@@ -121,7 +121,6 @@ class ResultsWriter(object):
             self.f.flush()

 def get_monitor_files(dir):
     return glob(osp.join(dir, "*" + Monitor.EXT))
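
For reference, a minimal usage sketch of the new behavior (not part of the commit; assumes a Gym CartPole-v0 environment): with filename=None, Monitor still tracks episode statistics but never constructs a ResultsWriter, so nothing is written to disk.

import gym
from baselines.bench.monitor import Monitor

env = Monitor(gym.make('CartPole-v0'), filename=None)  # no monitor file is created
ob, done = env.reset(), False
while not done:
    ob, rew, done, info = env.step(env.action_space.sample())
print(info['episode'])  # {'r': ..., 'l': ..., 't': ...} still reported on episode end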

baselines/common/mpi_util.py

@@ -19,15 +19,10 @@ def sync_from_root(sess, variables, comm=None):
       variables: all parameter variables including optimizer's
     """
     if comm is None: comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    for var in variables:
-        if rank == 0:
-            comm.Bcast(sess.run(var))
-        else:
-            import tensorflow as tf
-            returned_var = np.empty(var.shape, dtype='float32')
-            comm.Bcast(returned_var)
-            sess.run(tf.assign(var, returned_var))
+    import tensorflow as tf
+    values = comm.bcast(sess.run(variables))
+    sess.run([tf.assign(var, val)
+        for (var, val) in zip(variables, values)])

 def gpu_count():
     """

baselines/common/tests/test_mpi_util.py

@@ -6,21 +6,22 @@ from baselines.common import mpi_util
 def test_mpi_weighted_mean():
     from mpi4py import MPI
     comm = MPI.COMM_WORLD
-    if comm.rank == 0:
-        name2valcount = {'a' : (10, 2), 'b' : (20,3)}
-    elif comm.rank == 1:
-        name2valcount = {'a' : (19, 1), 'c' : (42,3)}
-    else:
-        raise NotImplementedError
+    with logger.scoped_configure(comm=comm):
+        if comm.rank == 0:
+            name2valcount = {'a' : (10, 2), 'b' : (20,3)}
+        elif comm.rank == 1:
+            name2valcount = {'a' : (19, 1), 'c' : (42,3)}
+        else:
+            raise NotImplementedError

-    d = mpi_util.mpi_weighted_mean(comm, name2valcount)
-    correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
-    if comm.rank == 0:
-        assert d == correctval, f'{d} != {correctval}'
+        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
+        correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
+        if comm.rank == 0:
+            assert d == correctval, f'{d} != {correctval}'

-    for name, (val, count) in name2valcount.items():
-        for _ in range(count):
-            logger.logkv_mean(name, val)
-    d2 = logger.dumpkvs(mpi_mean=True)
-    if comm.rank == 0:
-        assert d2 == correctval
+        for name, (val, count) in name2valcount.items():
+            for _ in range(count):
+                logger.logkv_mean(name, val)
+        d2 = logger.dumpkvs()
+        if comm.rank == 0:
+            assert d2 == correctval

baselines/common/vec_env/vec_monitor.py

@@ -2,15 +2,22 @@ from . import VecEnvWrapper
 from baselines.bench.monitor import ResultsWriter
 import numpy as np
 import time
+from collections import deque

 class VecMonitor(VecEnvWrapper):
-    def __init__(self, venv, filename=None):
+    def __init__(self, venv, filename=None, keep_buf=0):
         VecEnvWrapper.__init__(self, venv)
         self.eprets = None
         self.eplens = None
         self.tstart = time.time()
-        self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
+        if filename:
+            self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
+        else:
+            self.results_writer = None
+        self.keep_buf = keep_buf
+        if self.keep_buf:
+            self.epret_buf = deque([], maxlen=keep_buf)
+            self.eplen_buf = deque([], maxlen=keep_buf)

     def reset(self):
         obs = self.venv.reset()
@@ -28,10 +35,13 @@ class VecMonitor(VecEnvWrapper):
             if done:
                 epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
                 info['episode'] = epinfo
+                if self.keep_buf:
+                    self.epret_buf.append(ret)
+                    self.eplen_buf.append(eplen)
                 self.eprets[i] = 0
                 self.eplens[i] = 0
-                self.results_writer.write_row(epinfo)
+                if self.results_writer:
+                    self.results_writer.write_row(epinfo)
             newinfos.append(info)
         return obs, rews, dones, newinfos
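
A usage sketch for the new keep_buf argument (not part of the commit; assumes Gym CartPole-v0 environments wrapped in a DummyVecEnv): with keep_buf > 0 and filename=None, VecMonitor keeps the last keep_buf episode returns and lengths in deques instead of writing a monitor file, which is convenient for rolling statistics.

import gym
import numpy as np
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_monitor import VecMonitor

venv = DummyVecEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
venv = VecMonitor(venv, filename=None, keep_buf=100)   # keep the last 100 episodes in memory
obs = venv.reset()
for _ in range(1000):
    actions = [venv.action_space.sample() for _ in range(venv.num_envs)]
    obs, rews, dones, infos = venv.step(actions)
print(np.mean(venv.epret_buf), np.mean(venv.eplen_buf))  # rolling mean return / length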

baselines/logger.py

@@ -7,6 +7,7 @@ import time
 import datetime
 import tempfile
 from collections import defaultdict
+from contextlib import contextmanager

 DEBUG = 10
 INFO = 20
@@ -211,17 +212,11 @@ def logkvs(d):
     for (k, v) in d.items():
         logkv(k, v)

-def dumpkvs(mpi_mean=False):
+def dumpkvs():
     """
     Write all of the diagnostics from the current iteration
-    mpi_mean: whether to average across MPI workers.  mpi_mean=False just
-    has each worker write its own stats (and under default settings
-    non-root workers don't write anything), whereas mpi_mean=True has
-    the root worker collect all of the stats and write the average,
-    and no one else writes anything.
     """
-    return get_current().dumpkvs(mpi_mean=mpi_mean)
+    return get_current().dumpkvs()

 def getkvs():
     return get_current().name2val
@@ -252,6 +247,9 @@ def set_level(level):
"""
get_current().set_level(level)
def set_comm(comm):
get_current().set_comm(comm)
def get_dir():
"""
Get directory that log files are being written to.
@@ -262,18 +260,14 @@ def get_dir():
 record_tabular = logkv
 dump_tabular = dumpkvs

-class ProfileKV:
-    """
-    Usage:
-    with logger.ProfileKV("interesting_scope"):
-        code
-    """
-    def __init__(self, n):
-        self.n = "wait_" + n
-    def __enter__(self):
-        self.t1 = time.time()
-    def __exit__(self ,type, value, traceback):
-        get_current().name2val[self.n] += time.time() - self.t1
+@contextmanager
+def profile_kv(scopename):
+    logkey = 'wait_' + scopename
+    tstart = time.time()
+    try:
+        yield
+    finally:
+        get_current().name2val[logkey] += time.time() - tstart

 def profile(n):
     """
@@ -283,7 +277,7 @@ def profile(n):
"""
def decorator_with_name(func):
def func_wrapper(*args, **kwargs):
with ProfileKV(n):
with profile_kv(n):
return func(*args, **kwargs)
return func_wrapper
return decorator_with_name
@@ -305,12 +299,13 @@ class Logger(object):
     # So that you can still log to the terminal without setting up any output files
     CURRENT = None  # Current logger being used by the free functions above

-    def __init__(self, dir, output_formats):
+    def __init__(self, dir, output_formats, comm=None):
         self.name2val = defaultdict(float)  # values this iteration
         self.name2cnt = defaultdict(int)
         self.level = INFO
         self.dir = dir
         self.output_formats = output_formats
+        self.comm = comm

     # Logging API, forwarded
     # ----------------------------------------
@@ -322,19 +317,16 @@ class Logger(object):
         self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
         self.name2cnt[key] = cnt + 1

-    def dumpkvs(self, mpi_mean=False):
-        if self.level == DISABLED: return
-        if mpi_mean:
+    def dumpkvs(self):
+        if self.comm is None:
+            d = self.name2val
+        else:
             from baselines.common import mpi_util
-            from mpi4py import MPI
-            comm = MPI.COMM_WORLD
-            d = mpi_util.mpi_weighted_mean(comm,
+            d = mpi_util.mpi_weighted_mean(self.comm,
                     {name : (val, self.name2cnt.get(name, 1))
                         for (name, val) in self.name2val.items()})
-            if comm.rank != 0:
+            if self.comm.rank != 0:
                 d['dummy'] = 1 # so we don't get a warning about empty dict
-        else:
-            d = self.name2val
         out = d.copy() # Return the dict for unit testing purposes
         for fmt in self.output_formats:
             if isinstance(fmt, KVWriter):
@@ -352,6 +344,9 @@ class Logger(object):
     def set_level(self, level):
         self.level = level

+    def set_comm(self, comm):
+        self.comm = comm
+
     def get_dir(self):
         return self.dir
@@ -366,7 +361,10 @@ class Logger(object):
             if isinstance(fmt, SeqWriter):
                 fmt.writeseq(map(str, args))

-def configure(dir=None, format_strs=None):
+def configure(dir=None, format_strs=None, comm=None):
+    """
+    If comm is provided, average all numerical stats across that comm
+    """
     if dir is None:
         dir = os.getenv('OPENAI_LOGDIR')
     if dir is None:
@@ -393,7 +391,7 @@ def configure(dir=None, format_strs=None):
     format_strs = filter(None, format_strs)
     output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

-    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
+    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
     log('Logging to %s'%dir)

 def _configure_default_logger():
@@ -406,17 +404,15 @@ def reset():
     Logger.CURRENT = Logger.DEFAULT
     log('Reset logger')

-class scoped_configure(object):
-    def __init__(self, dir=None, format_strs=None):
-        self.dir = dir
-        self.format_strs = format_strs
-        self.prevlogger = None
-    def __enter__(self):
-        self.prevlogger = get_current()
-        configure(dir=self.dir, format_strs=self.format_strs)
-    def __exit__(self, *args):
+@contextmanager
+def scoped_configure(dir=None, format_strs=None, comm=None):
+    prevlogger = Logger.CURRENT
+    configure(dir=dir, format_strs=format_strs, comm=comm)
+    try:
+        yield
+    finally:
         Logger.CURRENT.close()
-        Logger.CURRENT = self.prevlogger
+        Logger.CURRENT = prevlogger

 # ================================================================
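
A usage sketch of the comm-aware logger (not part of the commit; assumes mpi4py): instead of calling dumpkvs(mpi_mean=True), the comm is now attached to the logger once, via configure(comm=...), set_comm(), or the scoped_configure context manager, and dumpkvs() averages across it, with the root worker writing the result.

import time
from mpi4py import MPI
from baselines import logger

comm = MPI.COMM_WORLD
logger.configure(comm=comm)                    # or logger.set_comm(comm) on the current logger
for i in range(10):
    logger.logkv_mean('score', comm.rank + i)  # each worker accumulates its own running mean
with logger.profile_kv('sleep'):               # profile_kv replaces the old ProfileKV class
    time.sleep(0.1)
logger.dumpkvs()                               # root writes the weighted average across the comm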