Everyrl initial commit & a few minor baselines changes (#226)

* everyrl initial commit

* add keep_buf argument to VecMonitor

* logger changes: set_comm and fix to mpi_mean functionality

* if filename not provided, don't create ResultsWriter

* change variable syncing function to simplify its usage. now you should initialize from all mpi processes

* everyrl coinrun changes

* tf_distr changes, bugfix

* get_one

* bring back get_next to temporarily restore code

* lint fixes

* fix test

* rename profile function

* rename gaussian

* fix coinrun training script
Authored by John Schulman on 2019-01-24 15:43:26 -08:00, committed by Peter Zhokhov
parent cd8d3389ba
commit 82ebd4a153
5 changed files with 99 additions and 98 deletions

View File

@@ -16,11 +16,13 @@ class Monitor(Wrapper):
     def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
         Wrapper.__init__(self, env=env)
         self.tstart = time.time()
-        self.results_writer = ResultsWriter(
-            filename,
-            header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
-            extra_keys=reset_keywords + info_keywords
-        )
+        if filename:
+            self.results_writer = ResultsWriter(filename,
+                header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
+                extra_keys=reset_keywords + info_keywords
+            )
+        else:
+            self.results_writer = None
         self.reset_keywords = reset_keywords
         self.info_keywords = info_keywords
         self.allow_early_resets = allow_early_resets
@@ -68,8 +70,9 @@ class Monitor(Wrapper):
             self.episode_lengths.append(eplen)
             self.episode_times.append(time.time() - self.tstart)
             epinfo.update(self.current_reset_info)
-            self.results_writer.write_row(epinfo)
+            if self.results_writer:
+                self.results_writer.write_row(epinfo)
             assert isinstance(info, dict)
             if isinstance(info, dict):
                 info['episode'] = epinfo
@@ -96,24 +99,21 @@ class LoadMonitorResultsError(Exception):

 class ResultsWriter(object):
-    def __init__(self, filename=None, header='', extra_keys=()):
+    def __init__(self, filename, header='', extra_keys=()):
         self.extra_keys = extra_keys
-        if filename is None:
-            self.f = None
-            self.logger = None
-        else:
-            if not filename.endswith(Monitor.EXT):
-                if osp.isdir(filename):
-                    filename = osp.join(filename, Monitor.EXT)
-                else:
-                    filename = filename + "." + Monitor.EXT
-            self.f = open(filename, "wt")
-            if isinstance(header, dict):
-                header = '# {} \n'.format(json.dumps(header))
-            self.f.write(header)
-            self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
-            self.logger.writeheader()
-            self.f.flush()
+        assert filename is not None
+        if not filename.endswith(Monitor.EXT):
+            if osp.isdir(filename):
+                filename = osp.join(filename, Monitor.EXT)
+            else:
+                filename = filename + "." + Monitor.EXT
+        self.f = open(filename, "wt")
+        if isinstance(header, dict):
+            header = '# {} \n'.format(json.dumps(header))
+        self.f.write(header)
+        self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
+        self.logger.writeheader()
+        self.f.flush()

     def write_row(self, epinfo):
         if self.logger:
@@ -121,7 +121,6 @@ class ResultsWriter(object):
             self.f.flush()

 def get_monitor_files(dir):
     return glob(osp.join(dir, "*" + Monitor.EXT))
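With this change, Monitor handles a missing filename itself and ResultsWriter always opens a file (it now asserts that filename is not None). A minimal sketch of the resulting behavior; the env id and path below are illustrative, not from this commit:

import gym
from baselines.bench.monitor import Monitor

env = Monitor(gym.make("CartPole-v0"), filename=None)   # no ResultsWriter; episode stats still appear in info['episode']
logged = Monitor(gym.make("CartPole-v0"), "/tmp/run")    # writes /tmp/run.monitor.csv, as before
# ResultsWriter(None) would now fail its assert; pass filename=None to Monitor/VecMonitor instead.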

View File

@@ -19,15 +19,10 @@ def sync_from_root(sess, variables, comm=None):
     variables: all parameter variables including optimizer's
     """
     if comm is None: comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    for var in variables:
-        if rank == 0:
-            comm.Bcast(sess.run(var))
-        else:
-            import tensorflow as tf
-            returned_var = np.empty(var.shape, dtype='float32')
-            comm.Bcast(returned_var)
-            sess.run(tf.assign(var, returned_var))
+    import tensorflow as tf
+    values = comm.bcast(sess.run(variables))
+    sess.run([tf.assign(var, val)
+        for (var, val) in zip(variables, values)])

 def gpu_count():
     """

View File

@@ -6,21 +6,22 @@ from baselines.common import mpi_util
 def test_mpi_weighted_mean():
     from mpi4py import MPI
     comm = MPI.COMM_WORLD
-    if comm.rank == 0:
-        name2valcount = {'a' : (10, 2), 'b' : (20,3)}
-    elif comm.rank == 1:
-        name2valcount = {'a' : (19, 1), 'c' : (42,3)}
-    else:
-        raise NotImplementedError
+    with logger.scoped_configure(comm=comm):
+        if comm.rank == 0:
+            name2valcount = {'a' : (10, 2), 'b' : (20,3)}
+        elif comm.rank == 1:
+            name2valcount = {'a' : (19, 1), 'c' : (42,3)}
+        else:
+            raise NotImplementedError

-    d = mpi_util.mpi_weighted_mean(comm, name2valcount)
-    correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
-    if comm.rank == 0:
-        assert d == correctval, f'{d} != {correctval}'
+        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
+        correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
+        if comm.rank == 0:
+            assert d == correctval, f'{d} != {correctval}'

-    for name, (val, count) in name2valcount.items():
-        for _ in range(count):
-            logger.logkv_mean(name, val)
-    d2 = logger.dumpkvs(mpi_mean=True)
-    if comm.rank == 0:
-        assert d2 == correctval
+        for name, (val, count) in name2valcount.items():
+            for _ in range(count):
+                logger.logkv_mean(name, val)
+        d2 = logger.dumpkvs()
+        if comm.rank == 0:
+            assert d2 == correctval

View File

@@ -2,15 +2,22 @@ from . import VecEnvWrapper
 from baselines.bench.monitor import ResultsWriter
 import numpy as np
 import time
+from collections import deque

 class VecMonitor(VecEnvWrapper):
-    def __init__(self, venv, filename=None):
+    def __init__(self, venv, filename=None, keep_buf=0):
         VecEnvWrapper.__init__(self, venv)
         self.eprets = None
         self.eplens = None
         self.tstart = time.time()
-        self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
+        if filename:
+            self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
+        else:
+            self.results_writer = None
+        self.keep_buf = keep_buf
+        if self.keep_buf:
+            self.epret_buf = deque([], maxlen=keep_buf)
+            self.eplen_buf = deque([], maxlen=keep_buf)

     def reset(self):
         obs = self.venv.reset()
@@ -28,10 +35,13 @@ class VecMonitor(VecEnvWrapper):
             if done:
                 epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
                 info['episode'] = epinfo
+                if self.keep_buf:
+                    self.epret_buf.append(ret)
+                    self.eplen_buf.append(eplen)
                 self.eprets[i] = 0
                 self.eplens[i] = 0
-                self.results_writer.write_row(epinfo)
+                if self.results_writer:
+                    self.results_writer.write_row(epinfo)
             newinfos.append(info)
         return obs, rews, dones, newinfos
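keep_buf keeps the last keep_buf episode returns and lengths in in-memory deques (epret_buf, eplen_buf), in addition to or instead of the CSV file. A rough usage sketch, assuming the usual vec_env import paths; the CartPole setup and random actions are purely illustrative:

import gym
import numpy as np
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_monitor import VecMonitor

venv = DummyVecEnv([lambda: gym.make("CartPole-v0") for _ in range(4)])
venv = VecMonitor(venv, filename=None, keep_buf=100)   # no CSV; keep the last 100 episode stats in memory
venv.reset()
for _ in range(1000):
    actions = np.array([venv.action_space.sample() for _ in range(venv.num_envs)])
    obs, rews, dones, infos = venv.step(actions)
print('mean return over buffered episodes:', np.mean(venv.epret_buf))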

View File

@@ -7,6 +7,7 @@ import time
 import datetime
 import tempfile
 from collections import defaultdict
+from contextlib import contextmanager

 DEBUG = 10
 INFO = 20
@@ -211,17 +212,11 @@ def logkvs(d):
     for (k, v) in d.items():
         logkv(k, v)

-def dumpkvs(mpi_mean=False):
+def dumpkvs():
     """
     Write all of the diagnostics from the current iteration
-
-    mpi_mean: whether to average across MPI workers. mpi_mean=False just
-    has each worker write its own stats (and under default settings
-    non-root workers don't write anything), whereas mpi_mean=True has
-    the root worker collect all of the stats and write the average,
-    and no one else writes anything.
     """
-    return get_current().dumpkvs(mpi_mean=mpi_mean)
+    return get_current().dumpkvs()

 def getkvs():
     return get_current().name2val
@@ -252,6 +247,9 @@ def set_level(level):
     """
     get_current().set_level(level)

+def set_comm(comm):
+    get_current().set_comm(comm)
+
 def get_dir():
     """
     Get directory that log files are being written to.
@@ -262,18 +260,14 @@ def get_dir():
 record_tabular = logkv
 dump_tabular = dumpkvs

-class ProfileKV:
-    """
-    Usage:
-    with logger.ProfileKV("interesting_scope"):
-        code
-    """
-    def __init__(self, n):
-        self.n = "wait_" + n
-    def __enter__(self):
-        self.t1 = time.time()
-    def __exit__(self ,type, value, traceback):
-        get_current().name2val[self.n] += time.time() - self.t1
+@contextmanager
+def profile_kv(scopename):
+    logkey = 'wait_' + scopename
+    tstart = time.time()
+    try:
+        yield
+    finally:
+        get_current().name2val[logkey] += time.time() - tstart

 def profile(n):
     """
@@ -283,7 +277,7 @@ def profile(n):
     """
     def decorator_with_name(func):
        def func_wrapper(*args, **kwargs):
-            with ProfileKV(n):
+            with profile_kv(n):
                 return func(*args, **kwargs)
        return func_wrapper
    return decorator_with_name
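The class-based ProfileKV timer becomes a profile_kv context manager (the profile decorator keeps the same behavior); both accumulate elapsed time under a 'wait_'-prefixed logger key. A small sketch, where the timed step and function names are illustrative:

from baselines import logger

with logger.profile_kv("env_step"):           # adds elapsed seconds to the 'wait_env_step' key
    obs, rew, done, info = env.step(action)   # any code to be timed

@logger.profile("update")                     # decorator form; accumulates under 'wait_update'
def update(batch):
    ...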
@@ -305,12 +299,13 @@ class Logger(object):
     # So that you can still log to the terminal without setting up any output files
     CURRENT = None # Current logger being used by the free functions above

-    def __init__(self, dir, output_formats):
+    def __init__(self, dir, output_formats, comm=None):
         self.name2val = defaultdict(float) # values this iteration
         self.name2cnt = defaultdict(int)
         self.level = INFO
         self.dir = dir
         self.output_formats = output_formats
+        self.comm = comm

     # Logging API, forwarded
     # ----------------------------------------
@@ -322,19 +317,16 @@ class Logger(object):
             self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
             self.name2cnt[key] = cnt + 1

-    def dumpkvs(self, mpi_mean=False):
-        if self.level == DISABLED: return
-        if mpi_mean:
+    def dumpkvs(self):
+        if self.comm is None:
+            d = self.name2val
+        else:
             from baselines.common import mpi_util
-            from mpi4py import MPI
-            comm = MPI.COMM_WORLD
-            d = mpi_util.mpi_weighted_mean(comm,
+            d = mpi_util.mpi_weighted_mean(self.comm,
                 {name : (val, self.name2cnt.get(name, 1))
                     for (name, val) in self.name2val.items()})
-            if comm.rank != 0:
+            if self.comm.rank != 0:
                 d['dummy'] = 1 # so we don't get a warning about empty dict
-        else:
-            d = self.name2val
         out = d.copy() # Return the dict for unit testing purposes
         for fmt in self.output_formats:
             if isinstance(fmt, KVWriter):
@@ -352,6 +344,9 @@ class Logger(object):
     def set_level(self, level):
         self.level = level

+    def set_comm(self, comm):
+        self.comm = comm
+
     def get_dir(self):
         return self.dir
@@ -366,7 +361,10 @@ class Logger(object):
         if isinstance(fmt, SeqWriter):
             fmt.writeseq(map(str, args))

-def configure(dir=None, format_strs=None):
+def configure(dir=None, format_strs=None, comm=None):
+    """
+    If comm is provided, average all numerical stats across that comm
+    """
     if dir is None:
         dir = os.getenv('OPENAI_LOGDIR')
     if dir is None:
@@ -393,7 +391,7 @@ def configure(dir=None, format_strs=None):
     format_strs = filter(None, format_strs)
     output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

-    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
+    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
     log('Logging to %s'%dir)

 def _configure_default_logger():
@@ -406,17 +404,15 @@ def reset():
         Logger.CURRENT = Logger.DEFAULT
         log('Reset logger')

-class scoped_configure(object):
-    def __init__(self, dir=None, format_strs=None):
-        self.dir = dir
-        self.format_strs = format_strs
-        self.prevlogger = None
-    def __enter__(self):
-        self.prevlogger = get_current()
-        configure(dir=self.dir, format_strs=self.format_strs)
-    def __exit__(self, *args):
-        Logger.CURRENT.close()
-        Logger.CURRENT = self.prevlogger
+@contextmanager
+def scoped_configure(dir=None, format_strs=None, comm=None):
+    prevlogger = Logger.CURRENT
+    configure(dir=dir, format_strs=format_strs, comm=comm)
+    try:
+        yield
+    finally:
+        Logger.CURRENT.close()
+        Logger.CURRENT = prevlogger

 # ================================================================
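Taken together, the old mpi_mean flag moves into the logger's state: a Logger configured with a comm averages logkv_mean values across ranks at dump time, and only rank 0 writes output. A minimal sketch of the new pattern (the directory and key names are illustrative):

from mpi4py import MPI
from baselines import logger

comm = MPI.COMM_WORLD
logger.configure(dir='/tmp/exp', comm=comm)     # or logger.set_comm(comm) on an already-configured logger
logger.logkv_mean('reward', 1.0 * comm.rank)    # each rank contributes its own running mean
stats = logger.dumpkvs()                        # rank 0 gets and writes the cross-rank weighted mean

logger.scoped_configure(comm=comm) gives the same behavior as a context manager, as used in the test above.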