extra functionality in baselines.common.plot_util (#310)
* get plot_util from mt_experiments branch * add labels * unit tests for plot_util
This commit is contained in:
@@ -248,7 +248,10 @@ def plot_results(
|
|||||||
figsize=None,
|
figsize=None,
|
||||||
legend_outside=False,
|
legend_outside=False,
|
||||||
resample=0,
|
resample=0,
|
||||||
smooth_step=1.0
|
smooth_step=1.0,
|
||||||
|
tiling='vertical',
|
||||||
|
xlabel=None,
|
||||||
|
ylabel=None
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
Plot multiple Results objects
|
Plot multiple Results objects
|
||||||
@@ -300,9 +303,23 @@ def plot_results(
|
|||||||
sk2r[splitkey].append(result)
|
sk2r[splitkey].append(result)
|
||||||
assert len(sk2r) > 0
|
assert len(sk2r) > 0
|
||||||
assert isinstance(resample, int), "0: don't resample. <integer>: that many samples"
|
assert isinstance(resample, int), "0: don't resample. <integer>: that many samples"
|
||||||
nrows = len(sk2r)
|
if tiling == 'vertical' or tiling is None:
|
||||||
ncols = 1
|
nrows = len(sk2r)
|
||||||
figsize = figsize or (6, 6 * nrows)
|
ncols = 1
|
||||||
|
elif tiling == 'horizontal':
|
||||||
|
ncols = len(sk2r)
|
||||||
|
nrows = 1
|
||||||
|
elif tiling == 'symmetric':
|
||||||
|
import math
|
||||||
|
N = len(sk2r)
|
||||||
|
largest_divisor = 1
|
||||||
|
for i in range(1, int(math.sqrt(N))+1):
|
||||||
|
if N % i == 0:
|
||||||
|
largest_divisor = i
|
||||||
|
ncols = largest_divisor
|
||||||
|
nrows = N // ncols
|
||||||
|
figsize = figsize or (6 * ncols, 6 * nrows)
|
||||||
|
|
||||||
f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize)
|
f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize)
|
||||||
|
|
||||||
groups = list(set(group_fn(result) for result in allresults))
|
groups = list(set(group_fn(result) for result in allresults))
|
||||||
@@ -316,7 +333,9 @@ def plot_results(
|
|||||||
g2c = defaultdict(int)
|
g2c = defaultdict(int)
|
||||||
sresults = sk2r[sk]
|
sresults = sk2r[sk]
|
||||||
gresults = defaultdict(list)
|
gresults = defaultdict(list)
|
||||||
ax = axarr[isplit][0]
|
idx_row = isplit // ncols
|
||||||
|
idx_col = isplit % ncols
|
||||||
|
ax = axarr[idx_row][idx_col]
|
||||||
for result in sresults:
|
for result in sresults:
|
||||||
group = group_fn(result)
|
group = group_fn(result)
|
||||||
g2c[group] += 1
|
g2c[group] += 1
|
||||||
@@ -355,7 +374,7 @@ def plot_results(
|
|||||||
ymean = np.mean(ys, axis=0)
|
ymean = np.mean(ys, axis=0)
|
||||||
ystd = np.std(ys, axis=0)
|
ystd = np.std(ys, axis=0)
|
||||||
ystderr = ystd / np.sqrt(len(ys))
|
ystderr = ystd / np.sqrt(len(ys))
|
||||||
l, = axarr[isplit][0].plot(usex, ymean, color=color)
|
l, = axarr[idx_row][idx_col].plot(usex, ymean, color=color)
|
||||||
g2l[group] = l
|
g2l[group] = l
|
||||||
if shaded_err:
|
if shaded_err:
|
||||||
ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4)
|
ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4)
|
||||||
@@ -372,6 +391,17 @@ def plot_results(
|
|||||||
loc=2 if legend_outside else None,
|
loc=2 if legend_outside else None,
|
||||||
bbox_to_anchor=(1,1) if legend_outside else None)
|
bbox_to_anchor=(1,1) if legend_outside else None)
|
||||||
ax.set_title(sk)
|
ax.set_title(sk)
|
||||||
|
# add xlabels, but only to the bottom row
|
||||||
|
if xlabel is not None:
|
||||||
|
for ax in axarr[-1]:
|
||||||
|
plt.sca(ax)
|
||||||
|
plt.xlabel(xlabel)
|
||||||
|
# add ylabels, but only to left column
|
||||||
|
if ylabel is not None:
|
||||||
|
for ax in axarr[:,0]:
|
||||||
|
plt.sca(ax)
|
||||||
|
plt.ylabel(ylabel)
|
||||||
|
|
||||||
return f, axarr
|
return f, axarr
|
||||||
|
|
||||||
def regression_analysis(df):
|
def regression_analysis(df):
|
||||||
|
17
baselines/common/tests/test_plot_util.py
Normal file
17
baselines/common/tests/test_plot_util.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# smoke tests of plot_util
|
||||||
|
from baselines.common import plot_util as pu
|
||||||
|
from baselines.common.tests.util import smoketest
|
||||||
|
|
||||||
|
|
||||||
|
def test_plot_util():
|
||||||
|
nruns = 4
|
||||||
|
logdirs = [smoketest('--alg=ppo2 --env=CartPole-v0 --num_timesteps=10000') for _ in range(nruns)]
|
||||||
|
data = pu.load_results(logdirs)
|
||||||
|
assert len(data) == 4
|
||||||
|
|
||||||
|
_, axes = pu.plot_results(data[:1]); assert len(axes) == 1
|
||||||
|
_, axes = pu.plot_results(data, tiling='vertical'); assert axes.shape==(4,1)
|
||||||
|
_, axes = pu.plot_results(data, tiling='horizontal'); assert axes.shape==(1,4)
|
||||||
|
_, axes = pu.plot_results(data, tiling='symmetric'); assert axes.shape==(2,2)
|
||||||
|
_, axes = pu.plot_results(data, split_fn=lambda _: ''); assert len(axes) == 1
|
||||||
|
|
@@ -77,3 +77,16 @@ def rollout(env, model, n_trials):
|
|||||||
observations.append(episode_obs)
|
observations.append(episode_obs)
|
||||||
return observations, actions, rewards
|
return observations, actions, rewards
|
||||||
|
|
||||||
|
|
||||||
|
def smoketest(argstr, **kwargs):
|
||||||
|
import tempfile
|
||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
argstr = 'python -m baselines.run ' + argstr
|
||||||
|
for key, value in kwargs:
|
||||||
|
argstr += ' --{}={}'.format(key, value)
|
||||||
|
tempdir = tempfile.mkdtemp()
|
||||||
|
env = os.environ.copy()
|
||||||
|
env['OPENAI_LOGDIR'] = tempdir
|
||||||
|
subprocess.run(argstr.split(' '), env=env)
|
||||||
|
return tempdir
|
||||||
|
@@ -1,10 +1,6 @@
|
|||||||
from multiprocessing import Process
|
from baselines.common.tests.util import smoketest
|
||||||
import baselines.run
|
|
||||||
|
|
||||||
def _run(argstr):
|
def _run(argstr):
|
||||||
p = Process(target=baselines.run.main, args=('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' '))
|
smoketest('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr)
|
||||||
p.start()
|
|
||||||
p.join()
|
|
||||||
|
|
||||||
def test_popart():
|
def test_popart():
|
||||||
_run('--normalize_returns=True --popart=True')
|
_run('--normalize_returns=True --popart=True')
|
||||||
|
Reference in New Issue
Block a user