diff --git a/baselines/common/plot_util.py b/baselines/common/plot_util.py
index 26b1613..e15c508 100644
--- a/baselines/common/plot_util.py
+++ b/baselines/common/plot_util.py
@@ -248,7 +248,10 @@ def plot_results(
     figsize=None,
     legend_outside=False,
     resample=0,
-    smooth_step=1.0
+    smooth_step=1.0,
+    tiling='vertical',
+    xlabel=None,
+    ylabel=None
 ):
     '''
     Plot multiple Results objects
@@ -300,9 +303,27 @@ def plot_results(
         sk2r[splitkey].append(result)
     assert len(sk2r) > 0
     assert isinstance(resample, int), "0: don't resample. : that many samples"
-    nrows = len(sk2r)
-    ncols = 1
-    figsize = figsize or (6, 6 * nrows)
+    # choose the subplot grid (nrows x ncols) that holds one axis per split key
+    if tiling == 'vertical' or tiling is None:
+        nrows = len(sk2r)
+        ncols = 1
+    elif tiling == 'horizontal':
+        ncols = len(sk2r)
+        nrows = 1
+    elif tiling == 'symmetric':
+        import math
+        N = len(sk2r)
+        # the largest divisor of N that does not exceed sqrt(N) becomes the column count
+        largest_divisor = 1
+        for i in range(1, int(math.sqrt(N))+1):
+            if N % i == 0:
+                largest_divisor = i
+        ncols = largest_divisor
+        nrows = N // ncols
+    else:
+        raise ValueError("tiling must be 'vertical', 'horizontal', 'symmetric' or None, got {!r}".format(tiling))
+    figsize = figsize or (6 * ncols, 6 * nrows)
+
     f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize)
 
     groups = list(set(group_fn(result) for result in allresults))
@@ -316,7 +337,9 @@ def plot_results(
         g2c = defaultdict(int)
         sresults = sk2r[sk]
         gresults = defaultdict(list)
-        ax = axarr[isplit][0]
+        idx_row = isplit // ncols
+        idx_col = isplit % ncols
+        ax = axarr[idx_row][idx_col]
         for result in sresults:
             group = group_fn(result)
             g2c[group] += 1
@@ -355,7 +378,7 @@ def plot_results(
                 ymean = np.mean(ys, axis=0)
                 ystd = np.std(ys, axis=0)
                 ystderr = ystd / np.sqrt(len(ys))
-                l, = axarr[isplit][0].plot(usex, ymean, color=color)
+                l, = axarr[idx_row][idx_col].plot(usex, ymean, color=color)
                 g2l[group] = l
                 if shaded_err:
                     ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4)
@@ -372,6 +395,17 @@ def plot_results(
                 loc=2 if legend_outside else None,
                 bbox_to_anchor=(1,1) if legend_outside else None)
         ax.set_title(sk)
+    # add xlabels, but only to the bottom row
+    if xlabel is not None:
+        for ax in axarr[-1]:
+            plt.sca(ax)
+            plt.xlabel(xlabel)
+    # add ylabels, but only to left column
+    if ylabel is not None:
+        for ax in axarr[:,0]:
+            plt.sca(ax)
+            plt.ylabel(ylabel)
+
     return f, axarr
 
 def regression_analysis(df):
diff --git a/baselines/common/tests/test_plot_util.py b/baselines/common/tests/test_plot_util.py
new file mode 100644
index 0000000..be33308
--- /dev/null
+++ b/baselines/common/tests/test_plot_util.py
@@ -0,0 +1,17 @@
+# smoke tests of plot_util
+from baselines.common import plot_util as pu
+from baselines.common.tests.util import smoketest
+
+
+def test_plot_util():
+    nruns = 4
+    logdirs = [smoketest('--alg=ppo2 --env=CartPole-v0 --num_timesteps=10000') for _ in range(nruns)]
+    data = pu.load_results(logdirs)
+    assert len(data) == 4
+
+    _, axes = pu.plot_results(data[:1]); assert len(axes) == 1
+    _, axes = pu.plot_results(data, tiling='vertical'); assert axes.shape==(4,1)
+    _, axes = pu.plot_results(data, tiling='horizontal'); assert axes.shape==(1,4)
+    _, axes = pu.plot_results(data, tiling='symmetric'); assert axes.shape==(2,2)
+    _, axes = pu.plot_results(data, split_fn=lambda _: ''); assert len(axes) == 1
+
diff --git a/baselines/common/tests/util.py b/baselines/common/tests/util.py
index 441e3f7..b3d31fe 100644
--- a/baselines/common/tests/util.py
+++ b/baselines/common/tests/util.py
@@ -77,3 +77,24 @@ def rollout(env, model, n_trials):
         observations.append(episode_obs)
 
     return observations, actions, rewards
+
+def smoketest(argstr, **kwargs):
+    '''
+    Run `python -m baselines.run <argstr>` in a subprocess.
+
+    Each keyword argument is appended to the command line as --key=value.
+    A fresh temporary directory is created and exported to the subprocess
+    as OPENAI_LOGDIR; its path is returned. The exit status of the
+    subprocess is not checked.
+    '''
+    import tempfile
+    import subprocess
+    import os
+    argstr = 'python -m baselines.run ' + argstr
+    for key, value in kwargs.items():
+        argstr += ' --{}={}'.format(key, value)
+    tempdir = tempfile.mkdtemp()
+    env = os.environ.copy()
+    env['OPENAI_LOGDIR'] = tempdir
+    subprocess.run(argstr.split(' '), env=env)
+    return tempdir
diff --git a/baselines/ddpg/test_smoke.py b/baselines/ddpg/test_smoke.py
index a9fdc05..bd7eba6 100644
--- a/baselines/ddpg/test_smoke.py
+++ b/baselines/ddpg/test_smoke.py
@@ -1,10 +1,6 @@
-from multiprocessing import Process
-import baselines.run
-
+from baselines.common.tests.util import smoketest
 def _run(argstr):
-    p = Process(target=baselines.run.main, args=('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' '))
-    p.start()
-    p.join()
+    smoketest('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr)
 
 def test_popart():
     _run('--normalize_returns=True --popart=True')