extra functionality in baselines.common.plot_util (#310)
* get plot_util from mt_experiments branch * add labels * unit tests for plot_util
This commit is contained in:
@@ -248,7 +248,10 @@ def plot_results(
|
||||
figsize=None,
|
||||
legend_outside=False,
|
||||
resample=0,
|
||||
smooth_step=1.0
|
||||
smooth_step=1.0,
|
||||
tiling='vertical',
|
||||
xlabel=None,
|
||||
ylabel=None
|
||||
):
|
||||
'''
|
||||
Plot multiple Results objects
|
||||
@@ -300,9 +303,23 @@ def plot_results(
|
||||
sk2r[splitkey].append(result)
|
||||
assert len(sk2r) > 0
|
||||
assert isinstance(resample, int), "0: don't resample. <integer>: that many samples"
|
||||
nrows = len(sk2r)
|
||||
ncols = 1
|
||||
figsize = figsize or (6, 6 * nrows)
|
||||
if tiling == 'vertical' or tiling is None:
|
||||
nrows = len(sk2r)
|
||||
ncols = 1
|
||||
elif tiling == 'horizontal':
|
||||
ncols = len(sk2r)
|
||||
nrows = 1
|
||||
elif tiling == 'symmetric':
|
||||
import math
|
||||
N = len(sk2r)
|
||||
largest_divisor = 1
|
||||
for i in range(1, int(math.sqrt(N))+1):
|
||||
if N % i == 0:
|
||||
largest_divisor = i
|
||||
ncols = largest_divisor
|
||||
nrows = N // ncols
|
||||
figsize = figsize or (6 * ncols, 6 * nrows)
|
||||
|
||||
f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize)
|
||||
|
||||
groups = list(set(group_fn(result) for result in allresults))
|
||||
@@ -316,7 +333,9 @@ def plot_results(
|
||||
g2c = defaultdict(int)
|
||||
sresults = sk2r[sk]
|
||||
gresults = defaultdict(list)
|
||||
ax = axarr[isplit][0]
|
||||
idx_row = isplit // ncols
|
||||
idx_col = isplit % ncols
|
||||
ax = axarr[idx_row][idx_col]
|
||||
for result in sresults:
|
||||
group = group_fn(result)
|
||||
g2c[group] += 1
|
||||
@@ -355,7 +374,7 @@ def plot_results(
|
||||
ymean = np.mean(ys, axis=0)
|
||||
ystd = np.std(ys, axis=0)
|
||||
ystderr = ystd / np.sqrt(len(ys))
|
||||
l, = axarr[isplit][0].plot(usex, ymean, color=color)
|
||||
l, = axarr[idx_row][idx_col].plot(usex, ymean, color=color)
|
||||
g2l[group] = l
|
||||
if shaded_err:
|
||||
ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4)
|
||||
@@ -372,6 +391,17 @@ def plot_results(
|
||||
loc=2 if legend_outside else None,
|
||||
bbox_to_anchor=(1,1) if legend_outside else None)
|
||||
ax.set_title(sk)
|
||||
# add xlabels, but only to the bottom row
|
||||
if xlabel is not None:
|
||||
for ax in axarr[-1]:
|
||||
plt.sca(ax)
|
||||
plt.xlabel(xlabel)
|
||||
# add ylabels, but only to left column
|
||||
if ylabel is not None:
|
||||
for ax in axarr[:,0]:
|
||||
plt.sca(ax)
|
||||
plt.ylabel(ylabel)
|
||||
|
||||
return f, axarr
|
||||
|
||||
def regression_analysis(df):
|
||||
|
17
baselines/common/tests/test_plot_util.py
Normal file
17
baselines/common/tests/test_plot_util.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# smoke tests of plot_util
|
||||
from baselines.common import plot_util as pu
|
||||
from baselines.common.tests.util import smoketest
|
||||
|
||||
|
||||
def test_plot_util():
|
||||
nruns = 4
|
||||
logdirs = [smoketest('--alg=ppo2 --env=CartPole-v0 --num_timesteps=10000') for _ in range(nruns)]
|
||||
data = pu.load_results(logdirs)
|
||||
assert len(data) == 4
|
||||
|
||||
_, axes = pu.plot_results(data[:1]); assert len(axes) == 1
|
||||
_, axes = pu.plot_results(data, tiling='vertical'); assert axes.shape==(4,1)
|
||||
_, axes = pu.plot_results(data, tiling='horizontal'); assert axes.shape==(1,4)
|
||||
_, axes = pu.plot_results(data, tiling='symmetric'); assert axes.shape==(2,2)
|
||||
_, axes = pu.plot_results(data, split_fn=lambda _: ''); assert len(axes) == 1
|
||||
|
@@ -77,3 +77,16 @@ def rollout(env, model, n_trials):
|
||||
observations.append(episode_obs)
|
||||
return observations, actions, rewards
|
||||
|
||||
|
||||
def smoketest(argstr, **kwargs):
|
||||
import tempfile
|
||||
import subprocess
|
||||
import os
|
||||
argstr = 'python -m baselines.run ' + argstr
|
||||
for key, value in kwargs:
|
||||
argstr += ' --{}={}'.format(key, value)
|
||||
tempdir = tempfile.mkdtemp()
|
||||
env = os.environ.copy()
|
||||
env['OPENAI_LOGDIR'] = tempdir
|
||||
subprocess.run(argstr.split(' '), env=env)
|
||||
return tempdir
|
||||
|
@@ -1,10 +1,6 @@
|
||||
from multiprocessing import Process
|
||||
import baselines.run
|
||||
|
||||
from baselines.common.tests.util import smoketest
|
||||
def _run(argstr):
|
||||
p = Process(target=baselines.run.main, args=('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' '))
|
||||
p.start()
|
||||
p.join()
|
||||
smoketest('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr)
|
||||
|
||||
def test_popart():
|
||||
_run('--normalize_returns=True --popart=True')
|
||||
|
Reference in New Issue
Block a user