diff --git a/Dockerfile b/Dockerfile index 49a9c79..12e67be 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,8 @@ FROM python:3.6 +RUN apt-get -y update && apt-get -y install ffmpeg # RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv + ENV CODE_DIR /root/code COPY . $CODE_DIR/baselines diff --git a/README.md b/README.md index c5e1fde..e382a8b 100644 --- a/README.md +++ b/README.md @@ -109,17 +109,9 @@ python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=0 -- *NOTE:* At the moment Mujoco training uses VecNormalize wrapper for the environment which is not being saved correctly; so loading the models trained on Mujoco will not work well if the environment is recreated. If necessary, you can work around that by replacing RunningMeanStd by TfRunningMeanStd in [baselines/common/vec_env/vec_normalize.py](baselines/common/vec_env/vec_normalize.py#L12). This way, mean and std of environment normalizing wrapper will be saved in tensorflow variables and included in the model file; however, training is slower that way - hence not including it by default +## Loading and visualizing learning curves and other training metrics +See [here](docs/viz/viz.md) for instructions on how to load and display the training data. -## Using baselines with TensorBoard -Baselines logger can save data in the TensorBoard format. 
To do so, set environment variables `OPENAI_LOG_FORMAT` and `OPENAI_LOGDIR`: -```bash -export OPENAI_LOG_FORMAT='stdout,log,csv,tensorboard' # formats are comma-separated, but for tensorboard you only really need the last one -export OPENAI_LOGDIR=path/to/tensorboard/data -``` -And you can now start TensorBoard with: -```bash -tensorboard --logdir=$OPENAI_LOGDIR -``` ## Subpackages - [A2C](baselines/a2c) diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index 162e34d..90b9868 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -131,6 +131,8 @@ def common_arg_parser(): parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int) parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float) parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str) + parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int) + parser.add_argument('--save_video_length', help='Length of recorded video. 
Default: 200', default=200, type=int) parser.add_argument('--play', default=False, action='store_true') return parser diff --git a/baselines/common/plot_util.py b/baselines/common/plot_util.py new file mode 100644 index 0000000..6f4c272 --- /dev/null +++ b/baselines/common/plot_util.py @@ -0,0 +1,391 @@ +import matplotlib.pyplot as plt +import os.path as osp +import json +import os +import numpy as np +import pandas +from collections import defaultdict, namedtuple +from baselines.bench import monitor +from baselines.logger import read_json, read_csv + +def smooth(y, radius, mode='two_sided', valid_only=False): + ''' + Smooth signal y, where radius determines the size of the window + + mode='two_sided': + average over the window [max(index - radius, 0), min(index + radius, len(y)-1)] + mode='causal': + average over the window [max(index - radius, 0), index] + + valid_only: put nan in entries where the full-sized window is not available + + ''' + assert mode in ('two_sided', 'causal') + if len(y) < 2*radius+1: + return np.ones_like(y) * y.mean() + elif mode == 'two_sided': + convkernel = np.ones(2 * radius+1) + out = np.convolve(y, convkernel,mode='same') / np.convolve(np.ones_like(y), convkernel, mode='same') + if valid_only: + out[:radius] = out[-radius:] = np.nan + elif mode == 'causal': + convkernel = np.ones(radius) + out = np.convolve(y, convkernel,mode='full') / np.convolve(np.ones_like(y), convkernel, mode='full') + out = out[:-radius+1] + if valid_only: + out[:radius] = np.nan + return out + +def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8): + ''' + perform one-sided (causal) EMA (exponential moving average) + smoothing and resampling to an even grid with n points. + Does not do extrapolation, so we assume + xolds[0] <= low && high <= xolds[-1] + + Arguments: + + xolds: array or list - x values of data. Needs to be sorted in ascending order + yolds: array or list - y values of data. 
Has to have the same length as xolds + + low: float - min value of the new x grid. By default equals to xolds[0] + high: float - max value of the new x grid. By default equals to xolds[-1] + + n: int - number of points in new x grid + + decay_steps: float - EMA decay factor, expressed in new x grid steps. + + low_counts_threshold: float or int + - y values with counts less than this value will be set to NaN + + Returns: + tuple xs, ys, count_ys where + xs - array with new x grid + ys - array of EMA of y at each point of the new x grid + count_ys - array of EMA of y counts at each point of the new x grid + + ''' + + low = xolds[0] if low is None else low + high = xolds[-1] if high is None else high + + assert xolds[0] <= low, 'low = {} < xolds[0] = {} - extrapolation not permitted!'.format(low, xolds[0]) + assert xolds[-1] >= high, 'high = {} > xolds[-1] = {} - extrapolation not permitted!'.format(high, xolds[-1]) + assert len(xolds) == len(yolds), 'length of xolds ({}) and yolds ({}) do not match!'.format(len(xolds), len(yolds)) + + + xolds = xolds.astype('float64') + yolds = yolds.astype('float64') + + luoi = 0 # last unused old index + sum_y = 0. + count_y = 0. + xnews = np.linspace(low, high, n) + decay_period = (high - low) / (n - 1) * decay_steps + interstep_decay = np.exp(- 1. 
/ decay_steps) + sum_ys = np.zeros_like(xnews) + count_ys = np.zeros_like(xnews) + for i in range(n): + xnew = xnews[i] + sum_y *= interstep_decay + count_y *= interstep_decay + while True: + xold = xolds[luoi] + if xold <= xnew: + decay = np.exp(- (xnew - xold) / decay_period) + sum_y += decay * yolds[luoi] + count_y += decay + luoi += 1 + else: + break + if luoi >= len(xolds): + break + sum_ys[i] = sum_y + count_ys[i] = count_y + + ys = sum_ys / count_ys + ys[count_ys < low_counts_threshold] = np.nan + + return xnews, ys, count_ys + +def symmetric_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8): + ''' + perform symmetric EMA (exponential moving average) + smoothing and resampling to an even grid with n points. + Does not do extrapolation, so we assume + xolds[0] <= low && high <= xolds[-1] + + Arguments: + + xolds: array or list - x values of data. Needs to be sorted in ascending order + yolds: array or list - y values of data. Has to have the same length as xolds + + low: float - min value of the new x grid. By default equals to xolds[0] + high: float - max value of the new x grid. By default equals to xolds[-1] + + n: int - number of points in new x grid + + decay_steps: float - EMA decay factor, expressed in new x grid steps. 
+ + low_counts_threshold: float or int + - y values with counts less than this value will be set to NaN + + Returns: + tuple xs, ys, count_ys where + xs - array with new x grid + ys - array of EMA of y at each point of the new x grid + count_ys - array of EMA of y counts at each point of the new x grid + + ''' + xs, ys1, count_ys1 = one_sided_ema(xolds, yolds, low, high, n, decay_steps, low_counts_threshold=0) + _, ys2, count_ys2 = one_sided_ema(-xolds[::-1], yolds[::-1], -high, -low, n, decay_steps, low_counts_threshold=0) + ys2 = ys2[::-1] + count_ys2 = count_ys2[::-1] + count_ys = count_ys1 + count_ys2 + ys = (ys1 * count_ys1 + ys2 * count_ys2) / count_ys + ys[count_ys < low_counts_threshold] = np.nan + return xs, ys, count_ys + +Result = namedtuple('Result', 'monitor progress dirname metadata') +Result.__new__.__defaults__ = (None,) * len(Result._fields) + +def load_results(root_dir_or_dirs, enable_progress=True, enable_monitor=True, verbose=False): + ''' + load summaries of runs from a list of directories (including subdirectories) + Arguments: + + enable_progress: bool - if True, will attempt to load data from progress.csv files (data saved by logger). Default: True + + enable_monitor: bool - if True, will attempt to load data from monitor.csv files (data saved by Monitor environment wrapper). Default: True + + verbose: bool - if True, will print out list of directories from which the data is loaded. 
Default: False + + + Returns: + List of Result objects with the following fields: + - dirname - path to the directory data was loaded from + - metadata - run metadata (such as command-line arguments and anything else in metadata.json file + - monitor - if enable_monitor is True, this field contains pandas dataframe with loaded monitor.csv file (or aggregate of all *.monitor.csv files in the directory) + - progress - if enable_progress is True, this field contains pandas dataframe with loaded progress.csv file + ''' + if isinstance(root_dir_or_dirs, str): + rootdirs = [osp.expanduser(root_dir_or_dirs)] + else: + rootdirs = [osp.expanduser(d) for d in root_dir_or_dirs] + allresults = [] + for rootdir in rootdirs: + assert osp.exists(rootdir), "%s doesn't exist"%rootdir + for dirname, dirs, files in os.walk(rootdir): + if '-proc' in dirname: + files[:] = [] + continue + if set(['metadata.json', 'monitor.json', 'monitor.csv', 'progress.json', 'progress.csv']).intersection(files): + # used to be uncommented, which means do not go deeper than current directory if any of the data files + # are found + # dirs[:] = [] + result = {'dirname' : dirname} + if "metadata.json" in files: + with open(osp.join(dirname, "metadata.json"), "r") as fh: + result['metadata'] = json.load(fh) + progjson = osp.join(dirname, "progress.json") + progcsv = osp.join(dirname, "progress.csv") + if enable_progress: + if osp.exists(progjson): + result['progress'] = pandas.DataFrame(read_json(progjson)) + elif osp.exists(progcsv): + try: + result['progress'] = read_csv(progcsv) + except pandas.errors.EmptyDataError: + print('skipping progress file in ', dirname, 'empty data') + else: + if verbose: print('skipping %s: no progress file'%dirname) + + if enable_monitor: + try: + result['monitor'] = pandas.DataFrame(monitor.load_results(dirname)) + except monitor.LoadMonitorResultsError: + print('skipping %s: no monitor files'%dirname) + except Exception as e: + print('exception loading monitor file in %s: 
%s'%(dirname, e)) + + if result.get('monitor') is not None or result.get('progress') is not None: + allresults.append(Result(**result)) + if verbose: + print('successfully loaded %s'%dirname) + + if verbose: print('loaded %i results'%len(allresults)) + return allresults + +COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink', + 'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise', + 'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue'] + + +def default_xy_fn(r): + x = np.cumsum(r.monitor.l) + y = smooth(r.monitor.r, radius=10) + return x,y + +def default_split_fn(r): + import re + # match name between slash and - at the end of the string + # (slash in the beginning or - in the end or either may be missing) + match = re.search(r'[^/-]+(?=(-\d+)?\Z)', r.dirname) + if match: + return match.group(0) + +def plot_results( + allresults, *, + xy_fn=default_xy_fn, + split_fn=default_split_fn, + group_fn=default_split_fn, + average_group=False, + figsize=None, + legend_outside=False, + resample=0, + smooth_step=1.0, +): + ''' + Plot multiple Results objects + + xy_fn: function Result -> x,y - function that converts results objects into tuple of x and y values. + By default, x is cumsum of episode lengths, and y is episode rewards + + split_fn: function Result -> hashable - function that converts results objects into keys to split curves into sub-panels by. + That is, the results r for which split_fn(r) is different will be put on different sub-panels. + By default, the portion of r.dirname between last / and - is returned. The sub-panels are + stacked vertically in the figure. + + group_fn: function Result -> hashable - function that converts results objects into keys to group curves by. + That is, the results r for which group_fn(r) is the same will be put into the same group. + Curves in the same group have the same color (if average_group is False), or averaged over + (if average_group is True). 
The default value is the same as default value for split_fn + + average_group: bool - if True, will average the curves in the same group. The mean of the result is plotted, with lighter + shaded region around corresponding to the standard deviation, and darker shaded region corresponding to + the error of mean estimate (that is, standard deviation over square root of number of samples) + + figsize: tuple or None - size of the resulting figure (including sub-panels). By default, width is 6 and height is 6 times number of + sub-panels. + + + legend_outside: bool - if True, will place the legend outside of the sub-panels. + + resample: int - if not zero, size of the uniform grid in x direction to resample onto. Resampling is performed via symmetric + EMA smoothing (see the docstring for symmetric_ema). + Default is zero (no resampling). Note that if average_group is True, resampling is necessary; in that case, default + value is 512. + + smooth_step: float - when resampling (i.e. when resample > 0 or average_group is True), use this EMA decay parameter (in units of the new grid step). + See docstrings for decay_steps in symmetric_ema or one_sided_ema functions. + + ''' + + if split_fn is None: split_fn = lambda _ : '' + if group_fn is None: group_fn = lambda _ : '' + sk2r = defaultdict(list) # splitkey2results + for result in allresults: + splitkey = split_fn(result) + sk2r[splitkey].append(result) + assert len(sk2r) > 0 + assert isinstance(resample, int), "0: don't resample. 
: that many samples" + nrows = len(sk2r) + ncols = 1 + figsize = figsize or (6, 6 * nrows) + f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize) + + groups = list(set(group_fn(result) for result in allresults)) + + default_samples = 512 + if average_group: + resample = resample or default_samples + + for (isplit, sk) in enumerate(sorted(sk2r.keys())): + g2l = {} + g2c = defaultdict(int) + sresults = sk2r[sk] + gresults = defaultdict(list) + ax = axarr[isplit][0] + for result in sresults: + group = group_fn(result) + g2c[group] += 1 + x, y = xy_fn(result) + if x is None: x = np.arange(len(y)) + x, y = map(np.asarray, (x, y)) + if average_group: + gresults[group].append((x,y)) + else: + if resample: + x, y, counts = symmetric_ema(x, y, x[0], x[-1], resample, decay_steps=smooth_step) + l, = ax.plot(x, y, color=COLORS[groups.index(group) % len(COLORS)]) + g2l[group] = l + if average_group: + for group in sorted(groups): + xys = gresults[group] + if not any(xys): + continue + color = COLORS[groups.index(group)] + origxs = [xy[0] for xy in xys] + minxlen = min(map(len, origxs)) + def allequal(qs): + return all((q==qs[0]).all() for q in qs[1:]) + if resample: + low = max(x[0] for x in origxs) + high = min(x[-1] for x in origxs) + usex = np.linspace(low, high, resample) + ys = [] + for (x, y) in xys: + ys.append(symmetric_ema(x, y, low, high, resample, decay_steps=smooth_step)[1]) + else: + assert allequal([x[:minxlen] for x in origxs]),\ + 'If you want to average unevenly sampled data, set resample=' + usex = origxs[0] + ys = [xy[1][:minxlen] for xy in xys] + ymean = np.mean(ys, axis=0) + ystd = np.std(ys, axis=0) + ystderr = ystd / np.sqrt(len(ys)) + l, = axarr[isplit][0].plot(usex, ymean, color=color) + g2l[group] = l + ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4) + ax.fill_between(usex, ymean - ystd, ymean + ystd, color=color, alpha=.2) + + + # https://matplotlib.org/users/legend_guide.html + 
plt.tight_layout() + if any(g2l.keys()): + ax.legend( + g2l.values(), + ['%s (%i)'%(g, g2c[g]) for g in g2l] if average_group else g2l.keys(), + loc=2 if legend_outside else None, + bbox_to_anchor=(1,1) if legend_outside else None) + ax.set_title(sk) + return f, axarr + +def regression_analysis(df): + xcols = list(df.columns.copy()) + xcols.remove('score') + ycols = ['score'] + import statsmodels.api as sm + mod = sm.OLS(df[ycols], sm.add_constant(df[xcols]), hasconst=False) + res = mod.fit() + print(res.summary()) + +def test_smooth(): + norig = 100 + nup = 300 + ndown = 30 + xs = np.cumsum(np.random.rand(norig) * 10 / norig) + yclean = np.sin(xs) + ys = yclean + .1 * np.random.randn(yclean.size) + xup, yup, _ = symmetric_ema(xs, ys, xs.min(), xs.max(), nup, decay_steps=nup/ndown) + xdown, ydown, _ = symmetric_ema(xs, ys, xs.min(), xs.max(), ndown, decay_steps=ndown/ndown) + xsame, ysame, _ = symmetric_ema(xs, ys, xs.min(), xs.max(), norig, decay_steps=norig/ndown) + plt.plot(xs, ys, label='orig', marker='x') + plt.plot(xup, yup, label='up', marker='x') + plt.plot(xdown, ydown, label='down', marker='x') + plt.plot(xsame, ysame, label='same', marker='x') + plt.plot(xs, yclean, label='clean', marker='x') + plt.legend() + plt.show() + + diff --git a/baselines/common/vec_env/__init__.py b/baselines/common/vec_env/__init__.py index cb60531..075a139 100644 --- a/baselines/common/vec_env/__init__.py +++ b/baselines/common/vec_env/__init__.py @@ -32,6 +32,11 @@ class VecEnv(ABC): """ closed = False viewer = None + + metadata = { + 'render.modes': ['human', 'rgb_array'] + } + def __init__(self, num_envs, observation_space, action_space): self.num_envs = num_envs self.observation_space = observation_space diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py index 45f8822..c2b86dd 100644 --- a/baselines/common/vec_env/dummy_vec_env.py +++ b/baselines/common/vec_env/dummy_vec_env.py @@ -20,7 +20,6 @@ class DummyVecEnv(VecEnv): env 
= self.envs[0] VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) obs_space = env.observation_space - self.keys, shapes, dtypes = obs_space_info(obs_space) self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys } @@ -77,6 +76,6 @@ class DummyVecEnv(VecEnv): def render(self, mode='human'): if self.num_envs == 1: - self.envs[0].render(mode=mode) + return self.envs[0].render(mode=mode) else: - super().render(mode=mode) + return super().render(mode=mode) diff --git a/baselines/common/vec_env/test_video_recorder.py b/baselines/common/vec_env/test_video_recorder.py new file mode 100644 index 0000000..363404a --- /dev/null +++ b/baselines/common/vec_env/test_video_recorder.py @@ -0,0 +1,49 @@ +""" +Tests for the VecVideoRecorder vectorized-environment wrapper. +""" + +import gym +import pytest +import os +import glob +import tempfile + +from .dummy_vec_env import DummyVecEnv +from .shmem_vec_env import ShmemVecEnv +from .subproc_vec_env import SubprocVecEnv +from .vec_video_recorder import VecVideoRecorder + +@pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv)) +@pytest.mark.parametrize('num_envs', (1, 4)) +@pytest.mark.parametrize('video_length', (10, 100)) +@pytest.mark.parametrize('video_interval', (1, 50)) +def test_video_recorder(klass, num_envs, video_length, video_interval): + """ + Wrap an existing VecEnv with VecVideoRecorder, + Make (video_interval + video_length + 1) steps, + then check that the video files are present + """ + + def make_fn(): + env = gym.make('PongNoFrameskip-v4') + return env + fns = [make_fn for _ in range(num_envs)] + env = klass(fns) + + with tempfile.TemporaryDirectory() as video_path: + env = VecVideoRecorder(env, video_path, record_video_trigger=lambda x: x % video_interval == 0, video_length=video_length) + + env.reset() + for _ in range(video_interval + video_length + 1): + env.step([0] * num_envs) + env.close() + + + recorded_video = 
glob.glob(os.path.join(video_path, "*.mp4")) + + # first and second step + assert len(recorded_video) == 2 + # Files are not empty + assert all(os.stat(p).st_size != 0 for p in recorded_video) + + diff --git a/baselines/common/vec_env/vec_video_recorder.py b/baselines/common/vec_env/vec_video_recorder.py new file mode 100644 index 0000000..b4e7059 --- /dev/null +++ b/baselines/common/vec_env/vec_video_recorder.py @@ -0,0 +1,89 @@ +import os +from baselines import logger +from baselines.common.vec_env import VecEnvWrapper +from gym.wrappers.monitoring import video_recorder + + +class VecVideoRecorder(VecEnvWrapper): + """ + Wrap VecEnv to record rendered image as mp4 video. + """ + + def __init__(self, venv, directory, record_video_trigger, video_length=200): + """ + # Arguments + venv: VecEnv to wrap + directory: Where to save videos + record_video_trigger: + Function that defines when to start recording. + The function takes the current number of step, + and returns whether we should start recording or not. 
+ video_length: Length of recorded video + """ + + VecEnvWrapper.__init__(self, venv) + self.record_video_trigger = record_video_trigger + self.video_recorder = None + + self.directory = os.path.abspath(directory) + if not os.path.exists(self.directory): os.mkdir(self.directory) + + self.file_prefix = "vecenv" + self.file_infix = '{}'.format(os.getpid()) + self.step_id = 0 + self.video_length = video_length + + self.recording = False + self.recorded_frames = 0 + + def reset(self): + obs = self.venv.reset() + + self.start_video_recorder() + + return obs + + def start_video_recorder(self): + self.close_video_recorder() + + base_path = os.path.join(self.directory, '{}.video.{}.video{:06}'.format(self.file_prefix, self.file_infix, self.step_id)) + self.video_recorder = video_recorder.VideoRecorder( + env=self.venv, + base_path=base_path, + metadata={'step_id': self.step_id} + ) + + self.video_recorder.capture_frame() + self.recorded_frames = 1 + self.recording = True + + def _video_enabled(self): + return self.record_video_trigger(self.step_id) + + def step_wait(self): + obs, rews, dones, infos = self.venv.step_wait() + + self.step_id += 1 + if self.recording: + self.video_recorder.capture_frame() + self.recorded_frames += 1 + if self.recorded_frames > self.video_length: + logger.info("Saving video to ", self.video_recorder.path) + self.close_video_recorder() + elif self._video_enabled(): + self.start_video_recorder() + + return obs, rews, dones, infos + + def close_video_recorder(self): + if self.recording: + self.video_recorder.close() + self.recording = False + self.recorded_frames = 0 + + def close(self): + VecEnvWrapper.close(self) + self.close_video_recorder() + + def __del__(self): + self.close() diff --git a/baselines/run.py b/baselines/run.py index 28cf620..c0298f3 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -6,6 +6,7 @@ from collections import defaultdict import tensorflow as tf import numpy as np +from baselines.common.vec_env.vec_video_recorder 
import VecVideoRecorder from baselines.common.vec_env.vec_frame_stack import VecFrameStack from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env from baselines.common.tf_util import get_session @@ -62,6 +63,8 @@ def train(args, extra_args): alg_kwargs.update(extra_args) env = build_env(args) + if args.save_video_interval != 0: + env = VecVideoRecorder(env, osp.join(logger.Logger.CURRENT.dir, "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length) if args.network: alg_kwargs['network'] = args.network diff --git a/docs/viz/viz.md b/docs/viz/viz.md new file mode 100644 index 0000000..d54ca37 --- /dev/null +++ b/docs/viz/viz.md @@ -0,0 +1,117 @@ +# Loading and visualizing results +In order to compare performance of algorithms, we often would like to visualize learning curves (reward as a function of time steps), or some other auxiliary information about learning +aggregated into a plot. Baselines repo provides tools for doing so in several different ways, depending on the goal. + +## Preliminaries +For all algorithms in baselines summary data is saved into a folder defined by logger. By default, a folder `$TMPDIR/openai--