docstrings in plot_util

This commit is contained in:
Peter Zhokhov
2018-11-05 10:02:45 -08:00
parent 1fc5e137b2
commit 527acf123f
2 changed files with 179 additions and 69 deletions

View File

@@ -36,69 +36,115 @@ def smooth(y, radius, mode='two_sided', valid_only=False):
out[:radius] = np.nan
return out
def one_sided_ema(xolds, yolds, low, high, n, decay_steps=1.):
def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8):
'''
perform one-sided (causal) EMA (exponential moving average)
smoothing and resampling to an even grid with n points.
Does not do extrapolation, so we assume
xolds[0] <= low && high <= xolds[-1]
Arguments:
xolds: array or list - x values of data. Needs to be sorted in ascending order
yolds: array of list - y values of data. Has to have the same length as xolds
low: float - min value of the new x grid. By default equals to xolds[0]
high: float - max value of the new x grid. By default equals to xolds[-1]
n: int - number of points in new x grid
decay_steps: float - EMA decay factor, expressed in new x grid steps.
low_counts_threshold: float or int
- y values with counts less than this value will be set to NaN
Returns:
tuple sum_ys, count_ys where
xs - array with new x grid
ys - array of EMA of y at each point of the new x grid
count_ys - array of EMA of y counts at each point of the new x grid
'''
low = xolds[0] if low is None else low
high = xolds[-1] if high is None else high
assert xolds[0] <= low, 'low = {} < xolds[0] = {} - extrapolation not permitted!'.format(low, xolds[0])
assert xolds[-1] >= high, 'high = {} > xolds[-1] = {} - extrapolation not permitted!'.format(high, xolds[-1])
assert len(xolds) == len(yolds), 'length of xolds ({}) and yolds ({}) do not match!'.format(len(xolds), len(yolds))
xolds = xolds.astype('float64')
yolds = yolds.astype('float64')
luoi = 0 # last unused old index
sumy = 0.
county = 0.
sum_y = 0.
count_y = 0.
xnews = np.linspace(low, high, n)
decay_period = (high - low) / (n - 1) * decay_steps
interstepdecay = np.exp(- 1. / decay_steps)
sumys = np.zeros_like(xnews)
countys = np.zeros_like(xnews)
interstep_decay = np.exp(- 1. / decay_steps)
sum_ys = np.zeros_like(xnews)
count_ys = np.zeros_like(xnews)
for i in range(n):
xnew = xnews[i]
sumy *= interstepdecay
county *= interstepdecay
sum_y *= interstep_decay
count_y *= interstep_decay
while True:
xold = xolds[luoi]
if xold <= xnew:
decay = np.exp(- (xnew - xold) / decay_period)
sumy += decay * yolds[luoi]
county += decay
sum_y += decay * yolds[luoi]
count_y += decay
luoi += 1
else:
break
if luoi >= len(xolds):
break
sumys[i] = sumy
countys[i] = county
return sumys, countys
sum_ys[i] = sum_y
count_ys[i] = count_y
def smooth_uneven(xolds, yolds, low, high, n, decay_steps=1., mode='symmetric'):
import pyximport; pyximport.install(setup_args={"include_dirs":np.get_include()})
# from baselines.common import smooth_helpers #pylint: disable=E0611
xolds = xolds.astype('float64')
yolds = yolds.astype('float64')
if mode == 'causal':
sumys, countys = one_sided_ema(xolds, yolds, low, high, n, decay_steps)
elif mode == 'symmetric':
sumys1, countys1 = one_sided_ema(xolds, yolds, low, high, n, decay_steps)
sumys2, countys2 = one_sided_ema(-xolds[::-1], yolds[::-1], -high, -low, n, decay_steps)
sumys = sumys1 + sumys2[::-1]
countys = countys1 + countys2[::-1]
xs = np.linspace(low, high, n)
ys = sumys / countys
ys[countys < 1e-8] = np.nan
return xs, ys
ys = sum_ys / count_ys
ys[count_ys < low_counts_threshold] = np.nan
def test_smooth():
norig = 100
nup = 300
ndown = 30
xs = np.cumsum(np.random.rand(norig) * 10 / norig)
yclean = np.sin(xs)
ys = yclean + .1 * np.random.randn(yclean.size)
xup, yup = smooth_uneven(xs, ys, xs.min(), xs.max(), nup, decay_steps=nup/ndown)
xdown, ydown = smooth_uneven(xs, ys, xs.min(), xs.max(), ndown, decay_steps=ndown/ndown)
xsame, ysame = smooth_uneven(xs, ys, xs.min(), xs.max(), norig, decay_steps=norig/ndown)
plt.plot(xs, ys, label='orig', marker='x')
plt.plot(xup, yup, label='up', marker='x')
plt.plot(xdown, ydown, label='down', marker='x')
plt.plot(xsame, ysame, label='same', marker='x')
plt.plot(xs, yclean, label='clean', marker='x')
plt.legend()
plt.show()
return xnews, ys, count_ys
def symmetric_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8):
'''
perform symmetric EMA (exponential moving average)
smoothing and resampling to an even grid with n points.
Does not do extrapolation, so we assume
xolds[0] <= low && high <= xolds[-1]
Arguments:
xolds: array or list - x values of data. Needs to be sorted in ascending order
yolds: array of list - y values of data. Has to have the same length as xolds
low: float - min value of the new x grid. By default equals to xolds[0]
high: float - max value of the new x grid. By default equals to xolds[-1]
n: int - number of points in new x grid
decay_steps: float - EMA decay factor, expressed in new x grid steps.
low_counts_threshold: float or int
- y values with counts less than this value will be set to NaN
Returns:
tuple sum_ys, count_ys where
xs - array with new x grid
ys - array of EMA of y at each point of the new x grid
count_ys - array of EMA of y counts at each point of the new x grid
'''
xs, ys1, count_ys1 = one_sided_ema(xolds, yolds, low, high, n, decay_steps, low_counts_threshold=0)
_, ys2, count_ys2 = one_sided_ema(-xolds[::-1], yolds[::-1], -high, -low, n, decay_steps, low_counts_threshold=0)
ys2 = ys2[::-1]
count_ys2 = count_ys2[::-1]
count_ys = count_ys1 + count_ys2
ys = (ys1 * count_ys1 + ys2 * count_ys2) / count_ys
ys[count_ys < low_counts_threshold] = np.nan
return xs, ys, count_ys
Result = namedtuple('Result', 'monitor progress dirname metadata')
Result.__new__.__defaults__ = (None,) * len(Result._fields)
@@ -109,14 +155,14 @@ def load_results(root_dir_or_dirs, enable_progress=True, enable_monitor=True, ve
Arguments:
enable_progress: bool - if True, will attempt to load data from progress.csv files (data saved by logger). Default: True
enable_monitor: bool - if True, will attepmt to load data from monitor.csv files (data saved by Monitor environment wrapper). Default: True
verbose: bool - if True, will print out list of directories from which the data is loaded. Default: False
Returns:
List of Result objects with the following fields:
List of Result objects with the following fields:
- dirname - path to the directory data was loaded from
- metadata - run metadata (such as command-line arguments and anything else in metadata.json file
- monitor - if enable_monitor is True, this field contains pandas dataframe with loaded monitor.csv file (or aggregate of all *.monitor.csv files in the directory)
@@ -136,7 +182,7 @@ def load_results(root_dir_or_dirs, enable_progress=True, enable_monitor=True, ve
if set(['metadata.json', 'monitor.json', 'monitor.csv', 'progress.json', 'progress.csv']).intersection(files):
# used to be uncommented, which means do not go deeper than current directory if any of the data files
# are found
# dirs[:] = []
# dirs[:] = []
result = {'dirname' : dirname}
if "metadata.json" in files:
with open(osp.join(dirname, "metadata.json"), "r") as fh:
@@ -162,11 +208,11 @@ def load_results(root_dir_or_dirs, enable_progress=True, enable_monitor=True, ve
except Exception as e:
print('exception loading monitor file in %s: %s'%(dirname, e))
if result.get('monitor') is not None or result.get('progress') is not None:
if result.get('monitor') is not None or result.get('progress') is not None:
allresults.append(Result(**result))
if verbose:
print('successfully loaded %s'%dirname)
if verbose: print('loaded %i results'%len(allresults))
return allresults
@@ -182,24 +228,57 @@ def default_xy_fn(r):
def default_split_fn(r):
import re
# match name between slash and -<digits> at the end of the string
# match name between slash and -<digits> at the end of the string
# (slash in the beginning or -<digits> in the end or either may be missing)
match = re.search(r'[^/-]+(?=(-\d+)?\Z)', r.dirname)
if match:
return match.group(0)
def plot_results(
allresults,
allresults, *,
xy_fn=default_xy_fn,
split_fn=default_split_fn,
group_fn=default_split_fn,
average_group=False,
figsize=None,
legend_outside=False,
resample=0
resample=0,
smooth_step=1.0,
):
'''
plot multiple Results object
Plot multiple Results objects
xy_fn: function Result -> x,y - function that converts results objects into tuple of x and y values.
By default, x is cumsum of episode lengths, and y is episode rewards
split_fn: function Result -> hashable - function that converts results objects into keys to split curves into subpanels by.
That is, the results r for which split_fn(r) is different will be put on different subpanels.
By default, the portion of r.dirname between last / and -<digits> is returned. The subpanels are
stacked vertically in the figure.
group_fn: function Result -> hashable - function that converts results objects into keys to group curves by.
That is, the results r for which group_fn(r) is the same will be put into the same group.
Curves in the same group have the same color (if average_group is False), or averaged over
(if average_group is True). The default value is the same as default value for split_fn
average_group: bool - if True, will average the curves in the same group. The mean of the result is plotted, with lighter
shaded region around corresponding to the standard deviation, and darker shaded region corresponding to
the error of mean estimate (that is, standard deviation over square root of number of samples)
figsize: tuple or None - size of the resulting figure (including subpanels). By default, width is 6 and height is 6 times number of
subpanels.
legend_outside: bool - if True, will place the legend outside of the subpanels.
resample: int - if not zero, size of the uniform grid in x direction to resample onto. Resampling is performed via symmetric
EMA smoothing (see the docstring for symmetric_ema).
Default is zero (no resampling). Note that if average_group is True, resampling is necessary; in that case, default
value is 512.
smooth_step: float - when resampling (i.e. when resample > 0 or average_group is True), use this EMA decay parameter (in units of the new grid step).
See docstrings for decay_steps in symmetric_ema or one_sided_ema functions.
'''
if split_fn is None: split_fn = lambda _ : ''
@@ -217,6 +296,10 @@ def plot_results(
groups = list(set(group_fn(result) for result in allresults))
default_samples = 512
if average_group:
resample = resample or default_samples
for (isplit, sk) in enumerate(sorted(sk2r.keys())):
g2l = {}
g2c = defaultdict(int)
@@ -233,34 +316,39 @@ def plot_results(
gresults[group].append((x,y))
else:
if resample:
x, y = smooth_uneven(x, y, x[0], x[-1], resample)
x, y, counts = symmetric_ema(x, y, x[0], x[-1], resample, decay_steps=smooth_step)
l, = ax.plot(x, y, color=COLORS[groups.index(group) % len(COLORS)])
g2l[group] = l
if average_group:
for group in sorted(groups):
xys = gresults[group]
if not any(xys):
continue
color = COLORS[groups.index(group)]
origxs = [xy[0] for xy in xys]
minxlen = min(map(len, origxs))
def allequal(qs):
return all((q==qs[0]).all() for q in qs[1:])
if resample:
low = 0 # usually right thing to do
low = max(x[0] for x in origxs)
high = min(x[-1] for x in origxs)
usex = np.linspace(low, high, resample)
ys = []
for (x, y) in xys:
ys.append(smooth_uneven(x, y, low, high, resample)[1])
ys.append(symmetric_ema(x, y, low, high, resample, decay_steps=smooth_step)[1])
else:
assert allequal([x[:minxlen] for x in origxs]),\
'If you want to average unevenly sampled data, set resample=<number of samples you want>'
usex = origxs[0]
ys = [xy[1][:minxlen] for xy in xys]
ymean = np.mean(ys, axis=0)
ystderr = np.std(ys, axis=0) / np.sqrt(len(ys))
ystd = np.std(ys, axis=0)
ystderr = ystd / np.sqrt(len(ys))
l, = axarr[isplit][0].plot(usex, ymean, color=color)
g2l[group] = l
ax.fill_between(usex, ymean-ystderr, ymean+ystderr, color=color, alpha=.3)
ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4)
ax.fill_between(usex, ymean - ystd, ymean + ystd, color=color, alpha=.2)
# https://matplotlib.org/users/legend_guide.html
plt.tight_layout()
@@ -282,4 +370,22 @@ def regression_analysis(df):
res = mod.fit()
print(res.summary())
def test_smooth():
norig = 100
nup = 300
ndown = 30
xs = np.cumsum(np.random.rand(norig) * 10 / norig)
yclean = np.sin(xs)
ys = yclean + .1 * np.random.randn(yclean.size)
xup, yup = smooth_uneven(xs, ys, xs.min(), xs.max(), nup, decay_steps=nup/ndown)
xdown, ydown = smooth_uneven(xs, ys, xs.min(), xs.max(), ndown, decay_steps=ndown/ndown)
xsame, ysame = smooth_uneven(xs, ys, xs.min(), xs.max(), norig, decay_steps=norig/ndown)
plt.plot(xs, ys, label='orig', marker='x')
plt.plot(xup, yup, label='up', marker='x')
plt.plot(xdown, ydown, label='down', marker='x')
plt.plot(xsame, ysame, label='same', marker='x')
plt.plot(xs, yclean, label='clean', marker='x')
plt.legend()
plt.show()

View File

@@ -29,6 +29,8 @@ tensorboard --logdir=$OPENAI_LOGDIR
## Loading summaries of the results
If the summary overview provided by tensorboard is not sufficient, and you would like to either access to raw environment episode data, or use complex post-processing notavailable in tensorboard, you can load results into python as [pandas](https://pandas.pydata.org/) dataframes.
The colab notebook with the full version of the code is available [here](https://colab.research.google.com/drive/1Wez1SA9PmNkCoYc8Fvl53bhU3F8OffGm) (use "Open in playground" button to get a runnable version)
For instance, the following snippet:
```python
from baselines.common import plot_util as pu
@@ -66,7 +68,7 @@ plt.plot(np.cumsum(r.monitor.l), pu.smooth(r.monitor.r, radius=10))
We can also get a similar curve by using logger summaries (instead of raw episode data in monitor.csv):
```python
plt.plot(r.progress.total_timesteps, r.progres.eprewmean)
plt.plot(r.progress.total_timesteps, r.progress.eprewmean)
```
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-10-29%20at%205.04.31%20PM.png" width="730">
@@ -103,11 +105,13 @@ The results are split into two groups based on batch size and are plotted on a s
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-10-29%20at%205.53.45%20PM.png" width="700">
Showing all seeds on the same plot may be somewhat hard to comprehend and analyse. We can instead average over all seeds via the following command:
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-11-02%20at%204.42.52%20PM.png" width="720">
## Advanced preprocessing and smoothing
## Plotting: standalone
The lighter shade shows the standard deviation of data, and darker shade -
error in estimate of the mean (that is, standard deviation divided by sqare root of number of seeds)
Note that averaging over seeds requires resampling to a common grid, which, in turn, requires smoothing
(using language of signal processing, we need to do low-pass filtering before resampling to avoid aliasing effects).
You can change the amount of smoothing by adjusting `resample` and `smooth_step` arguments to achieve desired smoothing effect
See the docstring of `plot_util` function for more info.