Fix ppo2 with MPI bug, other minor fixes (#735)

* joshim5 changes (width and height to WarpFrame wrapper)

* match network output with action distribution via a linear layer only if necessary (#167)

* support color vs. grayscale option in WarpFrame wrapper (#166)

* support color vs. grayscale option in WarpFrame wrapper

* Support color in other wrappers

* Updated per Peters suggestions

* fixing test failures

* ppo2 with microbatches (#168)

* pass microbatch_size to the model during construction

* microbatch fixes and test (#169)

* microbatch fixes and test

* tiny cleanup

* added assertions to the test

* vpg-related fix

* Peterz joshim5 subclass ppo2 model (#170)

* microbatch fixes and test

* tiny cleanup

* added assertions to the test

* vpg-related fix

* subclassing the model to make microbatched version of model WIP

* made microbatched model a subclass of ppo2 Model

* flake8 complaint

* mpi-less ppo2 (resolving merge conflict)

* flake8 and mpi4py imports in ppo2/model.py

* more un-mpying

* merge master

* updates to the benchmark viewer code + autopep8 (#184)

* viz docs and syntactic sugar wip

* update viewer yaml to use persistent volume claims

* move plot_util to baselines.common, update links

* use 1Tb hard drive for results viewer

* small updates to benchmark vizualizer code

* autopep8

* autopep8

* any folder can be a benchmark

* massage games image a little bit

* fixed --preload option in app.py

* remove preload from run_viewer.sh

* remove pdb breakpoints

* update bench-viewer.yaml

* fixed bug (#185)

* fixed bug 

it's wrong to do the else statement, because no other nodes would start.

* changed the fix slightly
This commit is contained in:
pzhokhov
2018-11-26 17:56:41 -08:00
committed by GitHub
parent 25ecb64821
commit 97e039127f
2 changed files with 15 additions and 17 deletions

View File

@@ -122,10 +122,9 @@ class Model(object):
self.save = functools.partial(save_variables, sess=sess)
self.load = functools.partial(load_variables, sess=sess)
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
initialize()
else:
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
initialize()
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
if MPI is not None:
sync_from_root(sess, global_variables) #pylint: disable=E1101
def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):

View File

@@ -5,7 +5,7 @@ matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
import matplotlib.pyplot as plt
plt.rcParams['svg.fonttype'] = 'none'
from baselines.bench.monitor import load_results
from baselines.common import plot_util
X_TIMESTEPS = 'timesteps'
X_EPISODES = 'episodes'
@@ -16,7 +16,7 @@ POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
EPISODES_WINDOW = 100
COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
'brown', 'orange', 'teal', 'coral', 'lightblue', 'lime', 'lavender', 'turquoise',
'darkgreen', 'tan', 'salmon', 'gold', 'lightpurple', 'darkred', 'darkblue']
'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue']
def rolling_window(a, window):
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
@@ -50,7 +50,7 @@ def plot_curves(xy_list, xaxis, yaxis, title):
maxx = max(xy[0][-1] for xy in xy_list)
minx = 0
for (i, (x, y)) in enumerate(xy_list):
color = COLORS[i]
color = COLORS[i % len(COLORS)]
plt.scatter(x, y, s=2)
x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean) #So returns average of last EPISODE_WINDOW episodes
plt.plot(x, y_mean, color=color)
@@ -62,19 +62,18 @@ def plot_curves(xy_list, xaxis, yaxis, title):
fig.canvas.mpl_connect('resize_event', lambda event: plt.tight_layout())
plt.grid(True)
def plot_results(dirs, num_timesteps, xaxis, yaxis, task_name):
tslist = []
for dir in dirs:
ts = load_results(dir)
ts = ts[ts.l.cumsum() <= num_timesteps]
tslist.append(ts)
xy_list = [ts2xy(ts, xaxis, yaxis) for ts in tslist]
plot_curves(xy_list, xaxis, yaxis, task_name)
def split_by_task(taskpath):
return taskpath['dirname'].split('/')[-1].split('-')[0]
def plot_results(dirs, num_timesteps=10e6, xaxis=X_TIMESTEPS, yaxis=Y_REWARD, title='', split_fn=split_by_task):
results = plot_util.load_results(dirs)
plot_util.plot_results(results, xy_fn=lambda r: ts2xy(r['monitor'], xaxis, yaxis), split_fn=split_fn, average_group=True, resample=int(1e6))
# Example usage in jupyter-notebook
# from baselines import results_plotter
# from baselines.results_plotter import plot_results
# %matplotlib inline
# results_plotter.plot_results(["./log"], 10e6, results_plotter.X_TIMESTEPS, "Breakout")
# plot_results("./log")
# Here ./log is a directory containing the monitor.csv files
def main():