extra functionality in baselines.common.plot_util (#310)

* get plot_util from mt_experiments branch

* add labels

* unit tests for plot_util
This commit is contained in:
pzhokhov
2019-04-17 15:17:27 -07:00
committed by Peter Zhokhov
parent b83a66527d
commit a93dde3b2b
4 changed files with 68 additions and 12 deletions

View File

@@ -248,7 +248,10 @@ def plot_results(
figsize=None,
legend_outside=False,
resample=0,
smooth_step=1.0
smooth_step=1.0,
tiling='vertical',
xlabel=None,
ylabel=None
):
'''
Plot multiple Results objects
@@ -300,9 +303,23 @@ def plot_results(
sk2r[splitkey].append(result)
assert len(sk2r) > 0
assert isinstance(resample, int), "0: don't resample. <integer>: that many samples"
nrows = len(sk2r)
ncols = 1
figsize = figsize or (6, 6 * nrows)
if tiling == 'vertical' or tiling is None:
nrows = len(sk2r)
ncols = 1
elif tiling == 'horizontal':
ncols = len(sk2r)
nrows = 1
elif tiling == 'symmetric':
import math
N = len(sk2r)
largest_divisor = 1
for i in range(1, int(math.sqrt(N))+1):
if N % i == 0:
largest_divisor = i
ncols = largest_divisor
nrows = N // ncols
figsize = figsize or (6 * ncols, 6 * nrows)
f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize)
groups = list(set(group_fn(result) for result in allresults))
@@ -316,7 +333,9 @@ def plot_results(
g2c = defaultdict(int)
sresults = sk2r[sk]
gresults = defaultdict(list)
ax = axarr[isplit][0]
idx_row = isplit // ncols
idx_col = isplit % ncols
ax = axarr[idx_row][idx_col]
for result in sresults:
group = group_fn(result)
g2c[group] += 1
@@ -355,7 +374,7 @@ def plot_results(
ymean = np.mean(ys, axis=0)
ystd = np.std(ys, axis=0)
ystderr = ystd / np.sqrt(len(ys))
l, = axarr[isplit][0].plot(usex, ymean, color=color)
l, = axarr[idx_row][idx_col].plot(usex, ymean, color=color)
g2l[group] = l
if shaded_err:
ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4)
@@ -372,6 +391,17 @@ def plot_results(
loc=2 if legend_outside else None,
bbox_to_anchor=(1,1) if legend_outside else None)
ax.set_title(sk)
# add xlabels, but only to the bottom row
if xlabel is not None:
for ax in axarr[-1]:
plt.sca(ax)
plt.xlabel(xlabel)
# add ylabels, but only to left column
if ylabel is not None:
for ax in axarr[:,0]:
plt.sca(ax)
plt.ylabel(ylabel)
return f, axarr
def regression_analysis(df):

View File

@@ -0,0 +1,17 @@
# smoke tests of plot_util
from baselines.common import plot_util as pu
from baselines.common.tests.util import smoketest
def test_plot_util():
nruns = 4
logdirs = [smoketest('--alg=ppo2 --env=CartPole-v0 --num_timesteps=10000') for _ in range(nruns)]
data = pu.load_results(logdirs)
assert len(data) == 4
_, axes = pu.plot_results(data[:1]); assert len(axes) == 1
_, axes = pu.plot_results(data, tiling='vertical'); assert axes.shape==(4,1)
_, axes = pu.plot_results(data, tiling='horizontal'); assert axes.shape==(1,4)
_, axes = pu.plot_results(data, tiling='symmetric'); assert axes.shape==(2,2)
_, axes = pu.plot_results(data, split_fn=lambda _: ''); assert len(axes) == 1

View File

@@ -77,3 +77,16 @@ def rollout(env, model, n_trials):
observations.append(episode_obs)
return observations, actions, rewards
def smoketest(argstr, **kwargs):
import tempfile
import subprocess
import os
argstr = 'python -m baselines.run ' + argstr
for key, value in kwargs:
argstr += ' --{}={}'.format(key, value)
tempdir = tempfile.mkdtemp()
env = os.environ.copy()
env['OPENAI_LOGDIR'] = tempdir
subprocess.run(argstr.split(' '), env=env)
return tempdir

View File

@@ -1,10 +1,6 @@
from multiprocessing import Process
import baselines.run
from baselines.common.tests.util import smoketest
def _run(argstr):
p = Process(target=baselines.run.main, args=('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' '))
p.start()
p.join()
smoketest('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr)
def test_popart():
_run('--normalize_returns=True --popart=True')