* everyrl initial commit

* add keep_buf argument to VecMonitor

* logger changes: set_comm and fix to mpi_mean functionality

* if filename not provided, don't create ResultsWriter

* change variable syncing function to simplify its usage. now you should initialize from all mpi processes

* everyrl coinrun changes

* tf_distr changes, bugfix

* get_one

* bring back get_next to temporarily restore code

* lint fixes

* fix test

* rename profile function

* rename gaussian

* fix coinrun training script

* rl19

* remove everyrl dir which appeared in the merge for some reason

* readme

* fiddle with ddpg

* make ddpg work

* steps_total argument

* gpu count

* clean up hyperparams and shape math

* logging + saving

* configuration stuff

* fixes, smoke tests

* fix stats

* make load_results return dicts -- easier to create the same kind of objects with some other mechanism for passing to downstream functions

* benchmarks

* fix tests

* add dqn to tests, fix it

* minor

* turned annotated transformer (pytorch) into a script

* more refactoring

* jax stuff

* cluster

* minor

* copy & paste alec code

* sign error

* add huber, rename some parameters, snapshotting off by default

* remove jax stuff

* minor

* move maze env

* minor

* remove trailing spaces

* remove trailing space

* lint

* fix test breakage due to gym update

* rename function

* move maze back to codegen

* get recurrent ppo working

* enable both lstm and gru

* script to print table of benchmark results

* various

* fix dqn

* add fixup initializer, remove lastrew

* organize logging stats

* fix silly bug

* refactor models

* fix mpi usage

* check sync

* minor

* change vf coef, hps

* clean up slicing in ppo

* minor fixes

* caching transformer

* docstrings

* xf fixes

* get rid of 'B' and 'BT' arguments

* minor

* transformer example

* remove output_kind from base class until we have a better idea how to use it

* add comments, revert maze stuff

* flake8

* codegen lint

* fix codegen tests

* responded to peter's comments

* lint fixes
John Schulman authored on 2019-02-11 15:23:13 -08:00, committed by Peter Zhokhov
parent ecf5394226
commit fb6fd51fe6
4 changed files with 28 additions and 8 deletions


@@ -17,9 +17,16 @@ class MpiAdamOptimizer(tf.train.AdamOptimizer):
         num_tasks = self.comm.Get_size()
         buf = np.zeros(sum(sizes), np.float32)
+        sess = tf.get_default_session()
+        assert sess is not None
+        countholder = [0] # Counts how many times _collect_grads has been called
+        stat = tf.reduce_sum(grads_and_vars[0][1]) # sum of first variable
 
         def _collect_grads(flat_grad):
             self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
             np.divide(buf, float(num_tasks), out=buf)
+            if countholder[0] % 100 == 0:
+                check_synced(sess, self.comm, stat)
+            countholder[0] += 1
             return buf
 
         avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32)
@@ -27,5 +34,13 @@ class MpiAdamOptimizer(tf.train.AdamOptimizer):
         avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
         avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                               for g, (_, v) in zip(avg_grads, grads_and_vars)]
         return avg_grads_and_vars
 
+def check_synced(sess, comm, tfstat):
+    """
+    Check that 'tfstat' evaluates to the same thing on every MPI worker
+    """
+    localval = sess.run(tfstat)
+    vals = comm.gather(localval)
+    if comm.rank == 0:
+        assert all(val==vals[0] for val in vals[1:])
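
The check_synced helper added above gathers a scalar statistic from every worker and asserts on rank 0 that they all agree, which catches parameter drift between MPI processes. A minimal standalone sketch of the same idea using plain mpi4py and NumPy (no TensorFlow session; check_synced_value and the example statistic are illustrative names, not part of this commit):

# Sketch only: verify a value is identical across MPI workers.
# Run with e.g.: mpirun -np 4 python check_sync_sketch.py
import numpy as np
from mpi4py import MPI

def check_synced_value(comm, value):
    """Gather `value` from every rank; rank 0 asserts that all ranks agree."""
    vals = comm.gather(value, root=0)
    if comm.rank == 0:
        assert all(np.allclose(v, vals[0]) for v in vals[1:]), \
            "value differs across MPI workers (parameters may have drifted)"

if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    # stand-in for the per-worker statistic (the diff uses a tf.reduce_sum of the first variable)
    stat = np.float32(1.234)
    check_synced_value(comm, stat)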


@@ -4,6 +4,7 @@ import platform
 import shutil
 import subprocess
 import warnings
+import sys
 
 try:
     from mpi4py import MPI
@@ -35,13 +36,15 @@ def gpu_count():
 def setup_mpi_gpus():
     """
-    Set CUDA_VISIBLE_DEVICES using MPI.
+    Set CUDA_VISIBLE_DEVICES to MPI rank if not already set
     """
-    num_gpus = gpu_count()
-    if num_gpus == 0:
-        return
-    local_rank, _ = get_local_rank_size(MPI.COMM_WORLD)
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank % num_gpus)
+    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
+        if sys.platform == 'darwin': # This Assumes if you're on OSX you're just
+            ids = []                 # doing a smoke test and don't want GPUs
+        else:
+            lrank, _lsize = get_local_rank_size(MPI.COMM_WORLD)
+            ids = [lrank]
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, ids))
 
 def get_local_rank_size(comm):
     """
@@ -127,3 +130,4 @@ def mpi_weighted_mean(comm, local_name2valcount):
         return {name : name2sum[name] / name2count[name] for name in name2sum}
     else:
         return {}
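
For readers without an MPI setup, the new GPU-pinning policy can be exercised in isolation: an existing CUDA_VISIBLE_DEVICES is respected, macOS processes get no GPUs, and every other process is pinned to the GPU matching its local rank. The sketch below mirrors that logic; pin_gpu and its local_rank argument are illustrative stand-ins for setup_mpi_gpus and get_local_rank_size, not code from this commit.

import os
import sys

def pin_gpu(local_rank):
    # Same policy as the new setup_mpi_gpus, minus MPI: respect an explicit setting,
    # expose no GPUs on macOS, otherwise pin to the GPU matching the local rank.
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        return
    ids = [] if sys.platform == 'darwin' else [local_rank]
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, ids))

pin_gpu(local_rank=0)
print(os.environ.get('CUDA_VISIBLE_DEVICES'))  # '0' on Linux (if not already set), '' on macOS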


@@ -137,7 +137,6 @@ class VecEnv(ABC):
             self.viewer = rendering.SimpleImageViewer()
         return self.viewer
 
-
 class VecEnvWrapper(VecEnv):
     """
     An environment wrapper that applies to an entire batch


@@ -9,6 +9,7 @@ class VecMonitor(VecEnvWrapper):
         VecEnvWrapper.__init__(self, venv)
         self.eprets = None
         self.eplens = None
+        self.epcount = 0
         self.tstart = time.time()
         if filename:
             self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
@@ -38,6 +39,7 @@ class VecMonitor(VecEnvWrapper):
                 if self.keep_buf:
                     self.epret_buf.append(ret)
                     self.eplen_buf.append(eplen)
+                self.epcount += 1
                 self.eprets[i] = 0
                 self.eplens[i] = 0
                 if self.results_writer:
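
Taken together with the earlier keep_buf commit, the monitor now tracks a running episode count (epcount) and, when keep_buf is set, buffers of recent returns and lengths; with filename=None no ResultsWriter is created. A hedged usage sketch, assuming the usual baselines module paths and the VecMonitor(venv, filename, keep_buf) signature from the surrounding code, which are not shown in this diff:

# Usage sketch; module paths and constructor defaults are assumptions, not part of this commit.
import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_monitor import VecMonitor

# Four CartPole copies, monitored in memory only (filename=None -> no ResultsWriter).
venv = DummyVecEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
venv = VecMonitor(venv, filename=None, keep_buf=100)

venv.reset()
for _ in range(500):
    actions = [venv.action_space.sample() for _ in range(venv.num_envs)]
    obs, rews, dones, infos = venv.step(actions)

print('episodes completed:', venv.epcount)           # incremented in step_wait above
print('recent returns:', list(venv.epret_buf)[-5:])  # populated because keep_buf > 0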