Compare commits: gdb ... peterz_mpi

13 Commits

Author | SHA1 | Date
---|---|---
 | 58801032fc |
 | b4a149a75f |
 | c248bf9a46 |
 | d1f7d12743 |
 | f0d49fb67d |
 | ef2e7246c9 |
 | 3e3e2b7998 |
 | d00f3bce34 |
 | 72aa2f1251 |
 | ea7a52b652 |
 | 064c45fa76 |
 | 6f148fdb0d |
 | d96e20ff27 |
.travis.yml (12 changes)

@@ -5,10 +5,14 @@ python:
 services:
 - docker
 
+env:
+- DOCKER_SUFFIX=py36-nompi
+- DOCKER_SUFFIX=py36-mpi
+
 install:
 - pip install flake8
-- docker build . -t baselines-test
+- docker build -f test.dockerfile.${DOCKER_SUFFIX} -t baselines-test .
 
 script:
 - flake8 . --show-source --statistics
-- docker run baselines-test pytest -v --forked .
+- docker run baselines-test pytest -v .
README.md (14 changes)

@@ -1,5 +1,3 @@
-**Status:** Active (under active development, breaking changes may occur)
-
 <img src="data/logo.jpg" width=25% align="right" /> [](https://travis-ci.org/openai/baselines)
 
 # Baselines
@@ -111,9 +109,17 @@ python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=0 --
 
 *NOTE:* At the moment Mujoco training uses VecNormalize wrapper for the environment which is not being saved correctly; so loading the models trained on Mujoco will not work well if the environment is recreated. If necessary, you can work around that by replacing RunningMeanStd by TfRunningMeanStd in [baselines/common/vec_env/vec_normalize.py](baselines/common/vec_env/vec_normalize.py#L12). This way, mean and std of environment normalizing wrapper will be saved in tensorflow variables and included in the model file; however, training is slower that way - hence not including it by default
 
 ## Loading and vizualizing learning curves and other training metrics
 See [here](docs/viz/viz.ipynb) for instructions on how to load and display the training data.
 
+## Using baselines with TensorBoard
+Baselines logger can save data in the TensorBoard format. To do so, set environment variables `OPENAI_LOG_FORMAT` and `OPENAI_LOGDIR`:
+```bash
+export OPENAI_LOG_FORMAT='stdout,log,csv,tensorboard' # formats are comma-separated, but for tensorboard you only really need the last one
+export OPENAI_LOGDIR=path/to/tensorboard/data
+```
+And you can now start TensorBoard with:
+```bash
+tensorboard --logdir=$OPENAI_LOGDIR
+```
 ## Subpackages
 
 - [A2C](baselines/a2c)
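A minimal sketch of the same TensorBoard setup driven from Python rather than the shell, using the `baselines.logger` module referenced above; the log-directory path is a hypothetical example.

```python
import os
os.environ['OPENAI_LOG_FORMAT'] = 'stdout,tensorboard'       # same formats as the README's export
os.environ['OPENAI_LOGDIR'] = os.path.expanduser('~/tb_logs')  # hypothetical path

from baselines import logger
logger.configure()              # picks up OPENAI_LOGDIR / OPENAI_LOG_FORMAT
logger.logkv('eprewmean', 1.0)  # record a scalar
logger.dumpkvs()                # event files land under OPENAI_LOGDIR
```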
@@ -37,6 +37,9 @@ class Runner(AbstractEnvRunner):
             obs, rewards, dones, _ = self.env.step(actions)
             self.states = states
             self.dones = dones
+            for n, done in enumerate(dones):
+                if done:
+                    self.obs[n] = self.obs[n]*0
             self.obs = obs
             mb_rewards.append(rewards)
             mb_dones.append(self.dones)
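The three added lines zero out the stale observation of any sub-environment that just finished an episode. A tiny numpy sketch of the same masking, with a hypothetical batch of stacked frames:

```python
import numpy as np

obs = np.ones((4, 84, 84, 4), dtype=np.uint8)  # hypothetical batch: 4 envs, stacked frames
dones = [False, True, False, False]
for n, done in enumerate(dones):
    if done:
        obs[n] = obs[n] * 0                    # clear the finished environment's slot
assert obs[1].sum() == 0 and obs[0].sum() > 0
```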
@@ -75,8 +75,8 @@ class Model(object):
         train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape)
         with tf.variable_scope('acer_model', reuse=tf.AUTO_REUSE):
 
-            step_model = policy(nbatch=nenvs, nsteps=1, observ_placeholder=step_ob_placeholder, sess=sess)
-            train_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)
+            step_model = policy(observ_placeholder=step_ob_placeholder, sess=sess)
+            train_model = policy(observ_placeholder=train_ob_placeholder, sess=sess)
 
 
         params = find_trainable_variables("acer_model")
@@ -94,7 +94,7 @@ class Model(object):
             return v
 
         with tf.variable_scope("acer_model", custom_getter=custom_getter, reuse=True):
-            polyak_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)
+            polyak_model = policy(observ_placeholder=train_ob_placeholder, sess=sess)
 
         # Notation: (var) = batch variable, (var)s = seqeuence variable, (var)_i = variable index by action at step i
@@ -156,10 +156,9 @@ register_benchmark({
 
 # HER DDPG
 
-_fetch_tasks = ['FetchReach-v1', 'FetchPush-v1', 'FetchSlide-v1']
 register_benchmark({
-    'name': 'Fetch1M',
-    'description': 'Fetch* benchmarks for 1M timesteps',
-    'tasks': [{'trials': 6, 'env_id': env_id, 'num_timesteps': int(1e6)} for env_id in _fetch_tasks]
+    'name': 'HerDdpg',
+    'description': 'Smoke-test only benchmark of HER',
+    'tasks': [{'trials': 1, 'env_id': 'FetchReach-v1'}]
 })
@@ -72,8 +72,8 @@ class EpisodicLifeEnv(gym.Wrapper):
         # then update lives to handle bonus lives
         lives = self.env.unwrapped.ale.lives()
         if lives < self.lives and lives > 0:
-            # for Qbert sometimes we stay in lives == 0 condition for a few frames
-            # so it's important to keep lives > 0, so that we only reset once
+            # for Qbert sometimes we stay in lives == 0 condtion for a few frames
+            # so its important to keep lives > 0, so that we only reset once
             # the environment advertises done.
             done = True
         self.lives = lives
@@ -129,26 +129,18 @@ class ClipRewardEnv(gym.RewardWrapper):
         return np.sign(reward)
 
 class WarpFrame(gym.ObservationWrapper):
-    def __init__(self, env, width=84, height=84, grayscale=True):
+    def __init__(self, env):
         """Warp frames to 84x84 as done in the Nature paper and later work."""
         gym.ObservationWrapper.__init__(self, env)
-        self.width = width
-        self.height = height
-        self.grayscale = grayscale
-        if self.grayscale:
-            self.observation_space = spaces.Box(low=0, high=255,
-                shape=(self.height, self.width, 1), dtype=np.uint8)
-        else:
-            self.observation_space = spaces.Box(low=0, high=255,
-                shape=(self.height, self.width, 3), dtype=np.uint8)
+        self.width = 84
+        self.height = 84
+        self.observation_space = spaces.Box(low=0, high=255,
+            shape=(self.height, self.width, 1), dtype=np.uint8)
 
     def observation(self, frame):
-        if self.grayscale:
-            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
         frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
-        if self.grayscale:
-            frame = np.expand_dims(frame, -1)
-        return frame
+        return frame[:, :, None]
 
 class FrameStack(gym.Wrapper):
     def __init__(self, env, k):
@@ -164,7 +156,7 @@ class FrameStack(gym.Wrapper):
         self.k = k
         self.frames = deque([], maxlen=k)
         shp = env.observation_space.shape
-        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype)
+        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
 
     def reset(self):
         ob = self.env.reset()
@@ -205,7 +197,7 @@ class LazyFrames(object):
 
     def _force(self):
         if self._out is None:
-            self._out = np.concatenate(self._frames, axis=-1)
+            self._out = np.concatenate(self._frames, axis=2)
             self._frames = None
         return self._out
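A quick shape check for the wrappers above (a sketch; it assumes gym with the Atari extras installed). With the default arguments, both sides of this diff produce (84, 84, 1) grayscale frames, and FrameStack multiplies the last axis by k:

```python
import gym
from baselines.common.atari_wrappers import WarpFrame, FrameStack

env = WarpFrame(gym.make('PongNoFrameskip-v4'))
print(env.observation_space.shape)  # (84, 84, 1)
env = FrameStack(env, 4)
print(env.observation_space.shape)  # (84, 84, 4) - frames stacked along the last axis
```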
@@ -18,50 +18,33 @@ from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
 from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
 from baselines.common import retro_wrappers
 
-def make_vec_env(env_id, env_type, num_env, seed,
-                 wrapper_kwargs=None,
-                 start_index=0,
-                 reward_scale=1.0,
-                 flatten_dict_observations=True,
-                 gamestate=None,
-                 initializer=None,
-                 env_kwargs=None,
-                 force_dummy=False):
+def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None):
     """
     Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
     """
-    wrapper_kwargs = wrapper_kwargs or {}
+    if wrapper_kwargs is None: wrapper_kwargs = {}
     mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
     seed = seed + 10000 * mpi_rank if seed is not None else None
-    logger_dir = logger.get_dir()
-    def make_thunk(rank, initializer=None):
+    def make_thunk(rank):
         return lambda: make_env(
             env_id=env_id,
             env_type=env_type,
-            mpi_rank=mpi_rank,
-            subrank=rank,
+            subrank = rank,
             seed=seed,
             reward_scale=reward_scale,
             gamestate=gamestate,
-            flatten_dict_observations=flatten_dict_observations,
-            wrapper_kwargs=wrapper_kwargs,
-            logger_dir=logger_dir,
-            initializer=initializer,
-            env_kwargs=env_kwargs,
+            wrapper_kwargs=wrapper_kwargs
         )
 
     set_global_seeds(seed)
-    if not force_dummy and num_env > 1:
-        return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)])
+    if num_env > 1:
+        return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
     else:
-        return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)])
+        return DummyVecEnv([make_thunk(start_index)])
 
 
-def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None, initializer=None, env_kwargs=None):
-    if initializer is not None:
-        initializer(mpi_rank=mpi_rank, subrank=subrank)
-
-    wrapper_kwargs = wrapper_kwargs or {}
+def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}):
+    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
     if env_type == 'atari':
         env = make_atari(env_id)
     elif env_type == 'retro':
@@ -69,26 +52,20 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
         gamestate = gamestate or retro.State.DEFAULT
         env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
     else:
-        env = gym.make(env_id, **(env_kwargs or {}))
-
-    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
-        keys = env.observation_space.spaces.keys()
-        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
+        env = gym.make(env_id)
 
     env.seed(seed + subrank if seed is not None else None)
     env = Monitor(env,
-                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
+                  logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
                   allow_early_resets=True)
 
     if env_type == 'atari':
-        env = wrap_deepmind(env, **wrapper_kwargs)
-    elif env_type == 'retro':
-        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)
-
-    if reward_scale != 1:
-        env = retro_wrappers.RewardScaler(env, reward_scale)
-
-    return env
+        return wrap_deepmind(env, **wrapper_kwargs)
+    elif reward_scale != 1:
+        return retro_wrappers.RewardScaler(env, reward_scale)
+    else:
+        return env
 
 
 def make_mujoco_env(env_id, seed, reward_scale=1.0):
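A usage sketch for make_vec_env as it appears on the left side of this hunk (the extras such as env_kwargs, initializer, and force_dummy exist only there); the environment id is just an example:

```python
from baselines.common.cmd_util import make_vec_env

# Four Atari workers behind a SubprocVecEnv; num_env=1 would fall back to DummyVecEnv.
venv = make_vec_env('PongNoFrameskip-v4', 'atari', num_env=4, seed=0)
obs = venv.reset()
print(obs.shape[0])  # 4 - one row per sub-environment
```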
@@ -144,7 +121,6 @@ def common_arg_parser():
     """
     parser = arg_parser()
     parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
-    parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
     parser.add_argument('--seed', help='RNG seed', type=int, default=None)
     parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
     parser.add_argument('--num_timesteps', type=float, default=1e6),
@@ -153,10 +129,7 @@ def common_arg_parser():
     parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
     parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
     parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
-    parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
-    parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
     parser.add_argument('--play', default=False, action='store_true')
-    parser.add_argument('--extra_import', help='Extra module to import to access external environments', type=str, default=None)
     return parser
 
 def robotics_arg_parser():
@@ -62,7 +62,7 @@ class CategoricalPdType(PdType):
     def pdclass(self):
         return CategoricalPd
     def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
-        pdparam = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
+        pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
         return self.pdfromflat(pdparam), pdparam
 
     def param_shape(self):
@@ -75,15 +75,14 @@ class CategoricalPdType(PdType):
 
 class MultiCategoricalPdType(PdType):
     def __init__(self, nvec):
-        self.ncats = nvec.astype('int32')
-        assert (self.ncats > 0).all()
+        self.ncats = nvec
     def pdclass(self):
         return MultiCategoricalPd
     def pdfromflat(self, flat):
         return MultiCategoricalPd(self.ncats, flat)
 
     def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
-        pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
+        pdparam = fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
         return self.pdfromflat(pdparam), pdparam
 
     def param_shape(self):
@@ -100,7 +99,7 @@ class DiagGaussianPdType(PdType):
         return DiagGaussianPd
 
     def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
-        mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
+        mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
         logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
         pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
         return self.pdfromflat(pdparam), mean
@@ -124,7 +123,7 @@ class BernoulliPdType(PdType):
     def sample_dtype(self):
         return tf.int32
     def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
-        pdparam = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
+        pdparam = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
         return self.pdfromflat(pdparam), pdparam
 
 # WRONG SECOND DERIVATIVES
@@ -346,9 +345,3 @@ def validate_probtype(probtype, pdparam):
     assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
     print('ok on', probtype, pdparam)
 
-
-def _matching_fc(tensor, name, size, init_scale, init_bias):
-    if tensor.shape[-1] == size:
-        return tensor
-    else:
-        return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)
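What the removed `_matching_fc` helper buys, in a minimal sketch: when the latent width already equals the parameter size, the tensor passes through unchanged instead of going through an extra fully-connected layer (assuming the usual `fc` helper from `baselines.a2c.utils`, which this file's `fc` calls refer to).

```python
import tensorflow as tf
from baselines.a2c.utils import fc  # assumed import; the diff itself does not show it

latent = tf.zeros([8, 6])
size = 6
# pass-through branch: shapes already match, so no new variables are created
pdparam = latent if latent.shape[-1] == size else fc(latent, 'pi', size)
```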
@@ -1,404 +0,0 @@
import matplotlib.pyplot as plt
import os.path as osp
import json
import os
import numpy as np
import pandas
from collections import defaultdict, namedtuple
from baselines.bench import monitor
from baselines.logger import read_json, read_csv

def smooth(y, radius, mode='two_sided', valid_only=False):
    '''
    Smooth signal y, where radius is determines the size of the window

    mode='twosided':
        average over the window [max(index - radius, 0), min(index + radius, len(y)-1)]
    mode='causal':
        average over the window [max(index - radius, 0), index]

    valid_only: put nan in entries where the full-sized window is not available

    '''
    assert mode in ('two_sided', 'causal')
    if len(y) < 2*radius+1:
        return np.ones_like(y) * y.mean()
    elif mode == 'two_sided':
        convkernel = np.ones(2 * radius+1)
        out = np.convolve(y, convkernel,mode='same') / np.convolve(np.ones_like(y), convkernel, mode='same')
        if valid_only:
            out[:radius] = out[-radius:] = np.nan
    elif mode == 'causal':
        convkernel = np.ones(radius)
        out = np.convolve(y, convkernel,mode='full') / np.convolve(np.ones_like(y), convkernel, mode='full')
        out = out[:-radius+1]
        if valid_only:
            out[:radius] = np.nan
    return out

def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8):
    '''
    perform one-sided (causal) EMA (exponential moving average)
    smoothing and resampling to an even grid with n points.
    Does not do extrapolation, so we assume
    xolds[0] <= low && high <= xolds[-1]

    Arguments:

    xolds: array or list - x values of data. Needs to be sorted in ascending order
    yolds: array of list - y values of data. Has to have the same length as xolds

    low: float - min value of the new x grid. By default equals to xolds[0]
    high: float - max value of the new x grid. By default equals to xolds[-1]

    n: int - number of points in new x grid

    decay_steps: float - EMA decay factor, expressed in new x grid steps.

    low_counts_threshold: float or int
                          - y values with counts less than this value will be set to NaN

    Returns:
        tuple sum_ys, count_ys where
            xs        - array with new x grid
            ys        - array of EMA of y at each point of the new x grid
            count_ys  - array of EMA of y counts at each point of the new x grid

    '''

    low = xolds[0] if low is None else low
    high = xolds[-1] if high is None else high

    assert xolds[0] <= low, 'low = {} < xolds[0] = {} - extrapolation not permitted!'.format(low, xolds[0])
    assert xolds[-1] >= high, 'high = {} > xolds[-1] = {} - extrapolation not permitted!'.format(high, xolds[-1])
    assert len(xolds) == len(yolds), 'length of xolds ({}) and yolds ({}) do not match!'.format(len(xolds), len(yolds))


    xolds = xolds.astype('float64')
    yolds = yolds.astype('float64')

    luoi = 0 # last unused old index
    sum_y = 0.
    count_y = 0.
    xnews = np.linspace(low, high, n)
    decay_period = (high - low) / (n - 1) * decay_steps
    interstep_decay = np.exp(- 1. / decay_steps)
    sum_ys = np.zeros_like(xnews)
    count_ys = np.zeros_like(xnews)
    for i in range(n):
        xnew = xnews[i]
        sum_y *= interstep_decay
        count_y *= interstep_decay
        while True:
            xold = xolds[luoi]
            if xold <= xnew:
                decay = np.exp(- (xnew - xold) / decay_period)
                sum_y += decay * yolds[luoi]
                count_y += decay
                luoi += 1
            else:
                break
            if luoi >= len(xolds):
                break
        sum_ys[i] = sum_y
        count_ys[i] = count_y

    ys = sum_ys / count_ys
    ys[count_ys < low_counts_threshold] = np.nan

    return xnews, ys, count_ys

def symmetric_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8):
    '''
    perform symmetric EMA (exponential moving average)
    smoothing and resampling to an even grid with n points.
    Does not do extrapolation, so we assume
    xolds[0] <= low && high <= xolds[-1]

    Arguments:

    xolds: array or list - x values of data. Needs to be sorted in ascending order
    yolds: array of list - y values of data. Has to have the same length as xolds

    low: float - min value of the new x grid. By default equals to xolds[0]
    high: float - max value of the new x grid. By default equals to xolds[-1]

    n: int - number of points in new x grid

    decay_steps: float - EMA decay factor, expressed in new x grid steps.

    low_counts_threshold: float or int
                          - y values with counts less than this value will be set to NaN

    Returns:
        tuple sum_ys, count_ys where
            xs        - array with new x grid
            ys        - array of EMA of y at each point of the new x grid
            count_ys  - array of EMA of y counts at each point of the new x grid

    '''
    xs, ys1, count_ys1 = one_sided_ema(xolds, yolds, low, high, n, decay_steps, low_counts_threshold=0)
    _, ys2, count_ys2 = one_sided_ema(-xolds[::-1], yolds[::-1], -high, -low, n, decay_steps, low_counts_threshold=0)
    ys2 = ys2[::-1]
    count_ys2 = count_ys2[::-1]
    count_ys = count_ys1 + count_ys2
    ys = (ys1 * count_ys1 + ys2 * count_ys2) / count_ys
    ys[count_ys < low_counts_threshold] = np.nan
    return xs, ys, count_ys

Result = namedtuple('Result', 'monitor progress dirname metadata')
Result.__new__.__defaults__ = (None,) * len(Result._fields)

def load_results(root_dir_or_dirs, enable_progress=True, enable_monitor=True, verbose=False):
    '''
    load summaries of runs from a list of directories (including subdirectories)
    Arguments:

    enable_progress: bool - if True, will attempt to load data from progress.csv files (data saved by logger). Default: True

    enable_monitor: bool - if True, will attempt to load data from monitor.csv files (data saved by Monitor environment wrapper). Default: True

    verbose: bool - if True, will print out list of directories from which the data is loaded. Default: False


    Returns:
    List of Result objects with the following fields:
         - dirname - path to the directory data was loaded from
         - metadata - run metadata (such as command-line arguments and anything else in metadata.json file
         - monitor - if enable_monitor is True, this field contains pandas dataframe with loaded monitor.csv file (or aggregate of all *.monitor.csv files in the directory)
         - progress - if enable_progress is True, this field contains pandas dataframe with loaded progress.csv file
    '''
    import re
    if isinstance(root_dir_or_dirs, str):
        rootdirs = [osp.expanduser(root_dir_or_dirs)]
    else:
        rootdirs = [osp.expanduser(d) for d in root_dir_or_dirs]
    allresults = []
    for rootdir in rootdirs:
        assert osp.exists(rootdir), "%s doesn't exist"%rootdir
        for dirname, dirs, files in os.walk(rootdir):
            if '-proc' in dirname:
                files[:] = []
                continue
            monitor_re = re.compile(r'(\d+\.)?(\d+\.)?monitor\.csv')
            if set(['metadata.json', 'monitor.json', 'progress.json', 'progress.csv']).intersection(files) or \
               any([f for f in files if monitor_re.match(f)]): # also match monitor files like 0.1.monitor.csv
                # used to be uncommented, which means do not go deeper than current directory if any of the data files
                # are found
                # dirs[:] = []
                result = {'dirname' : dirname}
                if "metadata.json" in files:
                    with open(osp.join(dirname, "metadata.json"), "r") as fh:
                        result['metadata'] = json.load(fh)
                progjson = osp.join(dirname, "progress.json")
                progcsv = osp.join(dirname, "progress.csv")
                if enable_progress:
                    if osp.exists(progjson):
                        result['progress'] = pandas.DataFrame(read_json(progjson))
                    elif osp.exists(progcsv):
                        try:
                            result['progress'] = read_csv(progcsv)
                        except pandas.errors.EmptyDataError:
                            print('skipping progress file in ', dirname, 'empty data')
                    else:
                        if verbose: print('skipping %s: no progress file'%dirname)

                if enable_monitor:
                    try:
                        result['monitor'] = pandas.DataFrame(monitor.load_results(dirname))
                    except monitor.LoadMonitorResultsError:
                        print('skipping %s: no monitor files'%dirname)
                    except Exception as e:
                        print('exception loading monitor file in %s: %s'%(dirname, e))

                if result.get('monitor') is not None or result.get('progress') is not None:
                    allresults.append(Result(**result))
                    if verbose:
                        print('successfully loaded %s'%dirname)

    if verbose: print('loaded %i results'%len(allresults))
    return allresults

COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
          'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise',
          'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue']


def default_xy_fn(r):
    x = np.cumsum(r.monitor.l)
    y = smooth(r.monitor.r, radius=10)
    return x,y

def default_split_fn(r):
    import re
    # match name between slash and -<digits> at the end of the string
    # (slash in the beginning or -<digits> in the end or either may be missing)
    match = re.search(r'[^/-]+(?=(-\d+)?\Z)', r.dirname)
    if match:
        return match.group(0)

def plot_results(
    allresults, *,
    xy_fn=default_xy_fn,
    split_fn=default_split_fn,
    group_fn=default_split_fn,
    average_group=False,
    shaded_std=True,
    shaded_err=True,
    figsize=None,
    legend_outside=False,
    resample=0,
    smooth_step=1.0,
):
    '''
    Plot multiple Results objects

    xy_fn: function Result -> x,y - function that converts results objects into tuple of x and y values.
        By default, x is cumsum of episode lengths, and y is episode rewards

    split_fn: function Result -> hashable - function that converts results objects into keys to split curves into sub-panels by.
        That is, the results r for which split_fn(r) is different will be put on different sub-panels.
        By default, the portion of r.dirname between last / and -<digits> is returned. The sub-panels are
        stacked vertically in the figure.

    group_fn: function Result -> hashable - function that converts results objects into keys to group curves by.
        That is, the results r for which group_fn(r) is the same will be put into the same group.
        Curves in the same group have the same color (if average_group is False), or averaged over
        (if average_group is True). The default value is the same as default value for split_fn

    average_group: bool - if True, will average the curves in the same group and plot the mean. Enables resampling
        (if resample = 0, will use 512 steps)

    shaded_std: bool - if True (default), the shaded region corresponding to standard deviation of the group of curves will be
        shown (only applicable if average_group = True)

    shaded_err: bool - if True (default), the shaded region corresponding to error in mean estimate of the group of curves
        (that is, standard deviation divided by square root of number of curves) will be
        shown (only applicable if average_group = True)

    figsize: tuple or None - size of the resulting figure (including sub-panels). By default, width is 6 and height is 6 times number of
        sub-panels.


    legend_outside: bool - if True, will place the legend outside of the sub-panels.

    resample: int - if not zero, size of the uniform grid in x direction to resample onto. Resampling is performed via symmetric
        EMA smoothing (see the docstring for symmetric_ema).
        Default is zero (no resampling). Note that if average_group is True, resampling is necessary; in that case, default
        value is 512.

    smooth_step: float - when resampling (i.e. when resample > 0 or average_group is True), use this EMA decay parameter (in units of the new grid step).
        See docstrings for decay_steps in symmetric_ema or one_sided_ema functions.

    '''

    if split_fn is None: split_fn = lambda _ : ''
    if group_fn is None: group_fn = lambda _ : ''
    sk2r = defaultdict(list) # splitkey2results
    for result in allresults:
        splitkey = split_fn(result)
        sk2r[splitkey].append(result)
    assert len(sk2r) > 0
    assert isinstance(resample, int), "0: don't resample. <integer>: that many samples"
    nrows = len(sk2r)
    ncols = 1
    figsize = figsize or (6, 6 * nrows)
    f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize)

    groups = list(set(group_fn(result) for result in allresults))

    default_samples = 512
    if average_group:
        resample = resample or default_samples

    for (isplit, sk) in enumerate(sorted(sk2r.keys())):
        g2l = {}
        g2c = defaultdict(int)
        sresults = sk2r[sk]
        gresults = defaultdict(list)
        ax = axarr[isplit][0]
        for result in sresults:
            group = group_fn(result)
            g2c[group] += 1
            x, y = xy_fn(result)
            if x is None: x = np.arange(len(y))
            x, y = map(np.asarray, (x, y))
            if average_group:
                gresults[group].append((x,y))
            else:
                if resample:
                    x, y, counts = symmetric_ema(x, y, x[0], x[-1], resample, decay_steps=smooth_step)
                l, = ax.plot(x, y, color=COLORS[groups.index(group) % len(COLORS)])
                g2l[group] = l
        if average_group:
            for group in sorted(groups):
                xys = gresults[group]
                if not any(xys):
                    continue
                color = COLORS[groups.index(group) % len(COLORS)]
                origxs = [xy[0] for xy in xys]
                minxlen = min(map(len, origxs))
                def allequal(qs):
                    return all((q==qs[0]).all() for q in qs[1:])
                if resample:
                    low = max(x[0] for x in origxs)
                    high = min(x[-1] for x in origxs)
                    usex = np.linspace(low, high, resample)
                    ys = []
                    for (x, y) in xys:
                        ys.append(symmetric_ema(x, y, low, high, resample, decay_steps=smooth_step)[1])
                else:
                    assert allequal([x[:minxlen] for x in origxs]),\
                        'If you want to average unevenly sampled data, set resample=<number of samples you want>'
                    usex = origxs[0]
                    ys = [xy[1][:minxlen] for xy in xys]
                ymean = np.mean(ys, axis=0)
                ystd = np.std(ys, axis=0)
                ystderr = ystd / np.sqrt(len(ys))
                l, = axarr[isplit][0].plot(usex, ymean, color=color)
                g2l[group] = l
                if shaded_err:
                    ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4)
                if shaded_std:
                    ax.fill_between(usex, ymean - ystd, ymean + ystd, color=color, alpha=.2)


        # https://matplotlib.org/users/legend_guide.html
        plt.tight_layout()
        if any(g2l.keys()):
            ax.legend(
                g2l.values(),
                ['%s (%i)'%(g, g2c[g]) for g in g2l] if average_group else g2l.keys(),
                loc=2 if legend_outside else None,
                bbox_to_anchor=(1,1) if legend_outside else None)
        ax.set_title(sk)
    return f, axarr

def regression_analysis(df):
    xcols = list(df.columns.copy())
    xcols.remove('score')
    ycols = ['score']
    import statsmodels.api as sm
    mod = sm.OLS(df[ycols], sm.add_constant(df[xcols]), hasconst=False)
    res = mod.fit()
    print(res.summary())

def test_smooth():
    norig = 100
    nup = 300
    ndown = 30
    xs = np.cumsum(np.random.rand(norig) * 10 / norig)
    yclean = np.sin(xs)
    ys = yclean + .1 * np.random.randn(yclean.size)
    xup, yup, _ = symmetric_ema(xs, ys, xs.min(), xs.max(), nup, decay_steps=nup/ndown)
    xdown, ydown, _ = symmetric_ema(xs, ys, xs.min(), xs.max(), ndown, decay_steps=ndown/ndown)
    xsame, ysame, _ = symmetric_ema(xs, ys, xs.min(), xs.max(), norig, decay_steps=norig/ndown)
    plt.plot(xs, ys, label='orig', marker='x')
    plt.plot(xup, yup, label='up', marker='x')
    plt.plot(xdown, ydown, label='down', marker='x')
    plt.plot(xsame, ysame, label='same', marker='x')
    plt.plot(xs, yclean, label='clean', marker='x')
    plt.legend()
    plt.show()
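A sketch of the workflow the deleted plotting module supported, pieced together from its docstrings above; it assumes some training runs have written monitor/progress files under ~/logs.

```python
from baselines.common import plot_util as pu

results = pu.load_results('~/logs')        # list of Result namedtuples
f, axarr = pu.plot_results(results,
                           average_group=True,  # mean curve per group...
                           resample=512,        # ...resampled via symmetric EMA
                           shaded_std=True)
```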
@@ -132,8 +132,10 @@ class MovieRecord(gym.Wrapper):
         self.epcount = 0
     def reset(self):
         if self.epcount % self.k == 0:
+            print('saving movie this episode', self.savedir)
             self.env.unwrapped.movie_path = self.savedir
         else:
+            print('not saving this episode')
             self.env.unwrapped.movie_path = None
             self.env.unwrapped.movie = None
         self.epcount += 1
@@ -1,39 +0,0 @@
import pytest
import gym

from baselines.run import get_learn_function
from baselines.common.tests.util import reward_per_episode_test

pytest.importorskip('mujoco_py')

common_kwargs = dict(
    network='mlp',
    seed=0,
)

learn_kwargs = {
    'her': dict(total_timesteps=2000)
}

@pytest.mark.slow
@pytest.mark.parametrize("alg", learn_kwargs.keys())
def test_fetchreach(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn the FetchReach task
    '''

    kwargs = common_kwargs.copy()
    kwargs.update(learn_kwargs[alg])

    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
    def env_fn():

        env = gym.make('FetchReach-v1')
        env.seed(0)
        return env

    reward_per_episode_test(env_fn, learn_fn, -15)

if __name__ == '__main__':
    test_fetchreach('her')
@@ -103,9 +103,9 @@ def test_coexistence(learn_fn, network_fn):
     kwargs.update(learn_kwargs[learn_fn])
 
     learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
-    make_session(make_default=True, graph=tf.Graph())
+    make_session(make_default=True, graph=tf.Graph());
     model1 = learn(seed=1)
-    make_session(make_default=True, graph=tf.Graph())
+    make_session(make_default=True, graph=tf.Graph());
     model2 = learn(seed=2)
 
     model1.step(env.observation_space.sample())
@@ -18,9 +18,7 @@ def test_function():
         initialize()
 
         assert lin(2) == 6
-        assert lin(x=3) == 9
         assert lin(2, 2) == 10
-        assert lin(x=2, y=3) == 12
 
 
 def test_multikwargs():
@@ -63,7 +63,7 @@ def rollout(env, model, n_trials):
 
     for i in range(n_trials):
         obs = env.reset()
-        state = model.initial_state if hasattr(model, 'initial_state') else None
+        state = model.initial_state
         episode_rew = []
         episode_actions = []
         episode_obs = []
@@ -165,10 +165,6 @@ def function(inputs, outputs, updates=None, givens=None):
     outputs: [tf.Variable] or tf.Variable
         list of outputs or a single output to be returned from function. Returned
         value will also have the same shape.
     updates: [tf.Operation] or tf.Operation
         list of update functions or single update function that will be run whenever
         the function is called. The return is ignored.
 
     """
     if isinstance(outputs, list):
         return _Function(inputs, outputs, updates, givens=givens)
@@ -186,7 +182,6 @@ class _Function(object):
             if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0):
                 assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method"
         self.inputs = inputs
-        self.input_names = {inp.name.split("/")[-1].split(":")[0]: inp for inp in inputs}
         updates = updates or []
         self.update_group = tf.group(*updates)
         self.outputs_update = list(outputs) + [self.update_group]
@@ -198,17 +193,15 @@ class _Function(object):
         else:
             feed_dict[inpt] = adjust_shape(inpt, value)
 
-    def __call__(self, *args, **kwargs):
-        assert len(args) + len(kwargs) <= len(self.inputs), "Too many arguments provided"
+    def __call__(self, *args):
+        assert len(args) <= len(self.inputs), "Too many arguments provided"
         feed_dict = {}
+        # Update feed dict with givens.
+        for inpt in self.givens:
+            feed_dict[inpt] = adjust_shape(inpt, feed_dict.get(inpt, self.givens[inpt]))
         # Update the args
         for inpt, value in zip(self.inputs, args):
             self._feed_input(feed_dict, inpt, value)
-        for inpt_name, value in kwargs.items():
-            self._feed_input(feed_dict, self.input_names[inpt_name], value)
-        # Update feed dict with givens.
-        for inpt in self.givens:
-            feed_dict[inpt] = adjust_shape(inpt, feed_dict.get(inpt, self.givens[inpt]))
         results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
         return results
@@ -340,7 +333,7 @@ def save_state(fname, sess=None):
 
 def save_variables(save_path, variables=None, sess=None):
     sess = sess or get_session()
-    variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+    variables = variables or tf.trainable_variables()
 
     ps = sess.run(variables)
     save_dict = {v.name: value for v, value in zip(variables, ps)}
@@ -351,7 +344,7 @@ def save_variables(save_path, variables=None, sess=None):
 
 def load_variables(load_path, variables=None, sess=None):
     sess = sess or get_session()
-    variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+    variables = variables or tf.trainable_variables()
 
     loaded_params = joblib.load(os.path.expanduser(load_path))
     restores = []
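A usage sketch for tf_util.function, following the docstring and the tests touched above; keyword calls such as lin(x=3) only work on the left side of this hunk, where __call__ still accepts **kwargs, while positional calls work on both sides.

```python
import tensorflow as tf
from baselines.common.tf_util import function, initialize, make_session

make_session(make_default=True)  # function() runs its ops in the default session

x = tf.placeholder(tf.int32, (), name='x')
y = tf.placeholder(tf.int32, (), name='y')
lin = function([x, y], 2 * x + y, givens={y: 0})
initialize()
assert lin(2) == 4     # y falls back to its given value, 0
assert lin(2, 2) == 6  # positional arguments
```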
@@ -32,11 +32,6 @@ class VecEnv(ABC):
     """
     closed = False
     viewer = None
-
-    metadata = {
-        'render.modes': ['human', 'rgb_array']
-    }
-
     def __init__(self, num_envs, observation_space, action_space):
         self.num_envs = num_envs
         self.observation_space = observation_space
@@ -20,6 +20,9 @@ class DummyVecEnv(VecEnv):
         env = self.envs[0]
         VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
         obs_space = env.observation_space
+        if isinstance(obs_space, spaces.MultiDiscrete):
+            obs_space.shape = obs_space.shape[0]
+
         self.keys, shapes, dtypes = obs_space_info(obs_space)
 
         self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }
@@ -27,7 +30,6 @@ class DummyVecEnv(VecEnv):
         self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
         self.buf_infos = [{} for _ in range(self.num_envs)]
         self.actions = None
-        self.specs = [e.spec for e in self.envs]
 
     def step_async(self, actions):
         listify = True
@@ -77,6 +79,6 @@ class DummyVecEnv(VecEnv):
 
     def render(self, mode='human'):
         if self.num_envs == 1:
-            return self.envs[0].render(mode=mode)
+            self.envs[0].render(mode=mode)
         else:
-            return super().render(mode=mode)
+            super().render(mode=mode)
@@ -54,7 +54,6 @@ class ShmemVecEnv(VecEnv):
             proc.start()
             child_pipe.close()
         self.waiting_step = False
-        self.specs = [f().spec for f in env_fns]
         self.viewer = None
 
     def reset(self):
@@ -57,7 +57,6 @@ class SubprocVecEnv(VecEnv):
         self.remotes[0].send(('get_spaces', None))
         observation_space, action_space = self.remotes[0].recv()
         self.viewer = None
-        self.specs = [f().spec for f in env_fns]
         VecEnv.__init__(self, len(env_fns), observation_space, action_space)
 
     def step_async(self, actions):
@@ -71,13 +70,13 @@ class SubprocVecEnv(VecEnv):
         results = [remote.recv() for remote in self.remotes]
         self.waiting = False
         obs, rews, dones, infos = zip(*results)
-        return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos
+        return np.stack(obs), np.stack(rews), np.stack(dones), infos
 
     def reset(self):
         self._assert_not_closed()
         for remote in self.remotes:
             remote.send(('reset', None))
-        return _flatten_obs([remote.recv() for remote in self.remotes])
+        return np.stack([remote.recv() for remote in self.remotes])
 
     def close_extras(self):
         self.closed = True
@@ -98,17 +97,3 @@ class SubprocVecEnv(VecEnv):
 
     def _assert_not_closed(self):
         assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
-
-
-def _flatten_obs(obs):
-    assert isinstance(obs, list) or isinstance(obs, tuple)
-    assert len(obs) > 0
-
-    if isinstance(obs[0], dict):
-        import collections
-        assert isinstance(obs, collections.OrderedDict)
-        keys = obs[0].keys()
-        return {k: np.stack([o[k] for o in obs]) for k in keys}
-    else:
-        return np.stack(obs)
@@ -1,49 +0,0 @@
"""
Tests for asynchronous vectorized environments.
"""

import gym
import pytest
import os
import glob
import tempfile

from .dummy_vec_env import DummyVecEnv
from .shmem_vec_env import ShmemVecEnv
from .subproc_vec_env import SubprocVecEnv
from .vec_video_recorder import VecVideoRecorder

@pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv))
@pytest.mark.parametrize('num_envs', (1, 4))
@pytest.mark.parametrize('video_length', (10, 100))
@pytest.mark.parametrize('video_interval', (1, 50))
def test_video_recorder(klass, num_envs, video_length, video_interval):
    """
    Wrap an existing VecEnv with VevVideoRecorder,
    Make (video_interval + video_length + 1) steps,
    then check that the file is present
    """

    def make_fn():
        env = gym.make('PongNoFrameskip-v4')
        return env
    fns = [make_fn for _ in range(num_envs)]
    env = klass(fns)

    with tempfile.TemporaryDirectory() as video_path:
        env = VecVideoRecorder(env, video_path, record_video_trigger=lambda x: x % video_interval == 0, video_length=video_length)

        env.reset()
        for _ in range(video_interval + video_length + 1):
            env.step([0] * num_envs)
        env.close()


        recorded_video = glob.glob(os.path.join(video_path, "*.mp4"))

        # first and second step
        assert len(recorded_video) == 2
        # Files are not empty
        assert all(os.stat(p).st_size != 0 for p in recorded_video)
@@ -1,89 +0,0 @@
import os
from baselines import logger
from baselines.common.vec_env import VecEnvWrapper
from gym.wrappers.monitoring import video_recorder


class VecVideoRecorder(VecEnvWrapper):
    """
    Wrap VecEnv to record rendered image as mp4 video.
    """

    def __init__(self, venv, directory, record_video_trigger, video_length=200):
        """
        # Arguments
            venv: VecEnv to wrap
            directory: Where to save videos
            record_video_trigger:
                Function that defines when to start recording.
                The function takes the current number of step,
                and returns whether we should start recording or not.
            video_length: Length of recorded video
        """

        VecEnvWrapper.__init__(self, venv)
        self.record_video_trigger = record_video_trigger
        self.video_recorder = None

        self.directory = os.path.abspath(directory)
        if not os.path.exists(self.directory): os.mkdir(self.directory)

        self.file_prefix = "vecenv"
        self.file_infix = '{}'.format(os.getpid())
        self.step_id = 0
        self.video_length = video_length

        self.recording = False
        self.recorded_frames = 0

    def reset(self):
        obs = self.venv.reset()

        self.start_video_recorder()

        return obs

    def start_video_recorder(self):
        self.close_video_recorder()

        base_path = os.path.join(self.directory, '{}.video.{}.video{:06}'.format(self.file_prefix, self.file_infix, self.step_id))
        self.video_recorder = video_recorder.VideoRecorder(
                env=self.venv,
                base_path=base_path,
                metadata={'step_id': self.step_id}
                )

        self.video_recorder.capture_frame()
        self.recorded_frames = 1
        self.recording = True

    def _video_enabled(self):
        return self.record_video_trigger(self.step_id)

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()

        self.step_id += 1
        if self.recording:
            self.video_recorder.capture_frame()
            self.recorded_frames += 1
            if self.recorded_frames > self.video_length:
                logger.info("Saving video to ", self.video_recorder.path)
                self.close_video_recorder()
        elif self._video_enabled():
            self.start_video_recorder()

        return obs, rews, dones, infos

    def close_video_recorder(self):
        if self.recording:
            self.video_recorder.close()
        self.recording = False
        self.recorded_frames = 0

    def close(self):
        VecEnvWrapper.close(self)
        self.close_video_recorder()

    def __del__(self):
        self.close()
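A usage sketch for the deleted VecVideoRecorder, following the deleted test above; it assumes gym with Atari support and ffmpeg available for gym's underlying video_recorder.

```python
import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder

env = DummyVecEnv([lambda: gym.make('PongNoFrameskip-v4')])
# start a new clip of video_length frames every 500 steps
env = VecVideoRecorder(env, '/tmp/videos',
                       record_video_trigger=lambda step: step % 500 == 0,
                       video_length=200)
env.reset()
for _ in range(701):
    env.step([0])
env.close()  # closes the recorder and flushes the .mp4 files
```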
@@ -66,6 +66,7 @@ def learn(network, env,
 
     action_noise = None
     param_noise = None
+    nb_actions = env.action_space.shape[-1]
     if noise_type is not None:
         for current_noise_type in noise_type.split(','):
             current_noise_type = current_noise_type.strip()
@@ -67,6 +67,7 @@ class DDPG(object):
     def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None,
         gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
         batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf),
+        adaptive_param_noise=True, adaptive_param_noise_policy_threshold=.1,
         critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
         # Inputs.
         self.obs0 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs0')
@@ -185,7 +186,7 @@ class DDPG(object):
         normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
         self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
         if self.critic_l2_reg > 0.:
-            critic_reg_vars = [var for var in self.critic.trainable_vars if var.name.endswith('/w:0') and 'output' not in var.name]
+            critic_reg_vars = [var for var in self.critic.trainable_vars if 'kernel' in var.name and 'output' not in var.name]
             for var in critic_reg_vars:
                 logger.info('  regularizing: {}'.format(var.name))
             logger.info('  applying l2 regularization with {}'.format(self.critic_l2_reg))
@@ -270,7 +271,7 @@ class DDPG(object):
 
         if self.action_noise is not None and apply_noise:
             noise = self.action_noise()
-            assert noise.shape == action[0].shape
+            assert noise.shape == action.shape
             action += noise
         action = np.clip(action, self.action_range[0], self.action_range[1])
@@ -42,7 +42,7 @@ class Critic(Model):
         with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
             x = tf.concat([obs, action], axis=-1) # this assumes observation and action can be concatenated
             x = self.network_builder(x)
-            x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3), name='output')
+            x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
         return x
 
     @property
@@ -1,17 +0,0 @@
from baselines.run import main as M

def _run(argstr):
    M(('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' '))

def test_popart():
    _run('--normalize_returns=True --popart=True')

def test_noise_normal():
    _run('--noise_type=normal_0.1')

def test_noise_ou():
    _run('--noise_type=ou_0.1')

def test_noise_adaptive():
    _run('--noise_type=adaptive-param_0.2,normal_0.1')
@@ -5,4 +5,4 @@ from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
 
 def wrap_atari_dqn(env):
     from baselines.common.atari_wrappers import wrap_deepmind
-    return wrap_deepmind(env, frame_stack=True, scale=False)
+    return wrap_deepmind(env, frame_stack=True, scale=True)
@@ -33,7 +33,7 @@ The functions in this file can are used to create the following functions:
     stochastic: bool
         if set to False all the actions are always deterministic (default False)
     update_eps_ph: float
-        update epsilon to a new value, if negative no update happens
+        update epsilon a new value, if negative not update happens
         (default: no update)
     reset_ph: bool
         reset the perturbed policy by sampling a new perturbation
@@ -2,9 +2,9 @@ import tensorflow as tf
 import tensorflow.contrib.layers as layers
 
 
-def _mlp(hiddens, input_, num_actions, scope, reuse=False, layer_norm=False):
+def _mlp(hiddens, inpt, num_actions, scope, reuse=False, layer_norm=False):
     with tf.variable_scope(scope, reuse=reuse):
-        out = input_
+        out = inpt
         for hidden in hiddens:
             out = layers.fully_connected(out, num_outputs=hidden, activation_fn=None)
             if layer_norm:
@@ -21,9 +21,6 @@ def mlp(hiddens=[], layer_norm=False):
     ----------
     hiddens: [int]
         list of sizes of hidden layers
-    layer_norm: bool
-        if true applies layer normalization for every layer
-        as described in https://arxiv.org/abs/1607.06450
 
     Returns
     -------
@@ -33,9 +30,9 @@ def mlp(hiddens=[], layer_norm=False):
     return lambda *args, **kwargs: _mlp(hiddens, layer_norm=layer_norm, *args, **kwargs)
 
 
-def _cnn_to_mlp(convs, hiddens, dueling, input_, num_actions, scope, reuse=False, layer_norm=False):
+def _cnn_to_mlp(convs, hiddens, dueling, inpt, num_actions, scope, reuse=False, layer_norm=False):
     with tf.variable_scope(scope, reuse=reuse):
-        out = input_
+        out = inpt
         with tf.variable_scope("convnet"):
             for num_outputs, kernel_size, stride in convs:
                 out = layers.convolution2d(out,
@@ -75,7 +72,7 @@ def cnn_to_mlp(convs, hiddens, dueling=False, layer_norm=False):
 
     Parameters
     ----------
-    convs: [(int, int, int)]
+    convs: [(int, int int)]
         list of convolutional layers in form of
         (num_outputs, kernel_size, stride)
     hiddens: [int]
@@ -83,9 +80,6 @@ def cnn_to_mlp(convs, hiddens, dueling=False, layer_norm=False):
     dueling: bool
         if true double the output MLP to compute a baseline
         for action scores
-    layer_norm: bool
-        if true applies layer normalization for every layer
-        as described in https://arxiv.org/abs/1607.06450
 
     Returns
     -------
@@ -24,7 +24,7 @@ Hopper-v1, Walker2d-v1, HalfCheetah-v1, Humanoid-v1, HumanoidStandup-v1. Every i
 
 For details (e.g., adversarial loss, discriminator accuracy, etc.) about GAIL training, please see [here](https://drive.google.com/drive/folders/1nnU8dqAV9i37-_5_vWIspyFUJFQLCsDD?usp=sharing)
 
-### Determinstic Policy (Set std=0)
+### Determinstic Polciy (Set std=0)
 | | Un-normalized | Normalized |
 |---|---|---|
 | Hopper-v1 | <img src='Hopper-unnormalized-deterministic-scores.png'> | <img src='Hopper-normalized-deterministic-scores.png'> |
|
@@ -6,29 +6,26 @@ For details on Hindsight Experience Replay (HER), please read the [paper](https:
|
||||
### Getting started
|
||||
Training an agent is very simple:
|
||||
```bash
|
||||
python -m baselines.run --alg=her --env=FetchReach-v1 --num_timesteps=5000
|
||||
python -m baselines.her.experiment.train
|
||||
```
|
||||
This will train a DDPG+HER agent on the `FetchReach` environment.
|
||||
You should see the success rate go up quickly to `1.0`, which means that the agent achieves the
|
||||
desired goal in 100% of the cases (note how HER can solve it in <5k steps - try doing that with PPO by replacing her with ppo2 :))
|
||||
The training script logs other diagnostics as well. Policy at the end of the training can be saved using `--save_path` flag, for instance:
|
||||
```bash
|
||||
python -m baselines.run --alg=her --env=FetchReach-v1 --num_timesteps=5000 --save_path=~/policies/her/fetchreach5k
|
||||
```
|
||||
desired goal in 100% of the cases.
|
||||
The training script logs other diagnostics as well and pickles the best policy so far (w.r.t. to its test success rate),
|
||||
the latest policy, and, if enabled, a history of policies every K epochs.
|
||||
|
||||
To inspect what the agent has learned, use the `--play` flag:
|
||||
To inspect what the agent has learned, use the play script:
|
||||
```bash
|
||||
python -m baselines.run --alg=her --env=FetchReach-v1 --num_timesteps=5000 --play
|
||||
python -m baselines.her.experiment.play /path/to/an/experiment/policy_best.pkl
|
||||
```
|
||||
(note `--play` can be combined with `--load_path`, which lets one load trained policies, for more results see [README.md](../../README.md))
|
||||
You can try it right now with the results of the training step (the script prints out the path for you).
|
||||
This should visualize the current policy for 10 episodes and will also print statistics.
|
||||
|
||||
|
||||
### Reproducing results
|
||||
In [Plappert et al. (2018)](https://arxiv.org/abs/1802.09464), 38 trajectories were generated in parallel
|
||||
(19 MPI processes, each generating computing gradients from 2 trajectories and aggregating).
|
||||
To reproduce that behaviour, use
|
||||
In order to reproduce the results from [Plappert et al. (2018)](https://arxiv.org/abs/1802.09464), run the following command:
|
||||
```bash
|
||||
mpirun -np 19 python -m baselines.run --num_env=2 --alg=her ...
|
||||
python -m baselines.her.experiment.train --num_cpu 19
|
||||
```
|
||||
This will require a machine with sufficient amount of physical CPU cores. In our experiments,
|
||||
we used [Azure's D15v2 instances](https://docs.microsoft.com/en-us/azure/virtual-machines/linux/sizes),
|
||||
@@ -48,13 +45,6 @@ python experiment/data_generation/fetch_data_generation.py
|
||||
```
|
||||
This outputs ```data_fetch_random_100.npz``` file which is our data file.
|
||||
|
||||
To launch training with demonstrations (more technically, with behaviour cloning loss as an auxilliary loss), run the following
|
||||
```bash
|
||||
python -m baselines.run --alg=her --env=FetchPickAndPlace-v1 --num_timesteps=2.5e6 --demo_file=/Path/to/demo_file.npz
|
||||
```
|
||||
This will train a DDPG+HER agent on the `FetchPickAndPlace` environment by using previously generated demonstration data.
|
||||
To inspect what the agent has learned, use the `--play` flag as described above.
|
||||
|
||||
#### Configuration
|
||||
The provided configuration is for training an agent with HER without demonstrations, we need to change a few paramters for the HER algorithm to learn through demonstrations, to do that, set:
|
||||
|
||||
@@ -72,7 +62,13 @@ Apart from these changes the reported results also have the following configurat
|
||||
* random_eps: 0.1 - percentage of time a random action is taken
|
||||
* noise_eps: 0.1 - std of gaussian noise added to not-completely-random actions
|
||||
|
||||
These parameters can be changed either in [experiment/config.py](experiment/config.py) or passed to the command line as `--param=value`)
|
||||
Now training an agent with pre-recorded demonstrations:
|
||||
```bash
|
||||
python -m baselines.her.experiment.train --env=FetchPickAndPlace-v0 --n_epochs=1000 --demo_file=/Path/to/demo_file.npz --num_cpu=1
|
||||
```
|
||||
|
||||
This will train a DDPG+HER agent on the `FetchPickAndPlace` environment by using previously generated demonstration data.
|
||||
To inspect what the agent has learned, use the play script as described above.
|
||||
|
||||
### Results
|
||||
Training with demonstrations helps overcome the exploration problem and achieves a faster and better convergence. The following graphs contrast the difference between training with and without demonstration data, We report the mean Q values vs Epoch and the Success Rate vs Epoch:
|
||||
@@ -82,4 +78,3 @@ Training with demonstrations helps overcome the exploration problem and achieves
|
||||
<center><img src="../../data/fetchPickAndPlaceContrast.png"></center>
|
||||
<div class="thecap" align="middle"><b>Training results for Fetch Pick and Place task constrasting between training with and without demonstration data.</b></div>
|
||||
</div>
|
||||
|
||||
|
@@ -10,14 +10,13 @@ from baselines.her.util import (
from baselines.her.normalizer import Normalizer
from baselines.her.replay_buffer import ReplayBuffer
from baselines.common.mpi_adam import MpiAdam
from baselines.common import tf_util


def dims_to_shapes(input_dims):
return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()}


global DEMO_BUFFER #buffer for demonstrations
global demoBuffer #buffer for demonstrations

class DDPG(object):
@store_args
@@ -95,16 +94,16 @@ class DDPG(object):
self._create_network(reuse=reuse)

# Configure the replay buffer.
buffer_shapes = {key: (self.T-1 if key != 'o' else self.T, *input_shapes[key])
buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *input_shapes[key])
for key, val in input_shapes.items()}
buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
buffer_shapes['ag'] = (self.T, self.dimg)
buffer_shapes['ag'] = (self.T+1, self.dimg)

buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)

global DEMO_BUFFER
DEMO_BUFFER = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions) #initialize the demo buffer; in the same way as the primary data buffer
global demoBuffer
demoBuffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions) #initialize the demo buffer; in the same way as the primary data buffer

def _random_action(self, n):
return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))
@@ -120,11 +119,6 @@ class DDPG(object):
g = np.clip(g, -self.clip_obs, self.clip_obs)
return o, g

def step(self, obs):
actions = self.get_actions(obs['observation'], obs['achieved_goal'], obs['desired_goal'])
return actions, None, None, None


def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
compute_Q=False):
o, g = self._preprocess_og(o, ag, g)
@@ -157,30 +151,25 @@ class DDPG(object):
else:
return ret

def init_demo_buffer(self, demoDataFile, update_stats=True): #function that initializes the demo buffer
def initDemoBuffer(self, demoDataFile, update_stats=True): #function that initializes the demo buffer

demoData = np.load(demoDataFile) #load the demonstration data from data file
info_keys = [key.replace('info_', '') for key in self.input_dims.keys() if key.startswith('info_')]
info_values = [np.empty((self.T - 1, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]

demo_data_obs = demoData['obs']
demo_data_acs = demoData['acs']
demo_data_info = demoData['info']
info_values = [np.empty((self.T, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]

for epsd in range(self.num_demo): # we initialize the whole demo buffer at the start of the training
obs, acts, goals, achieved_goals = [], [] ,[] ,[]
i = 0
for transition in range(self.T - 1):
obs.append([demo_data_obs[epsd][transition].get('observation')])
acts.append([demo_data_acs[epsd][transition]])
goals.append([demo_data_obs[epsd][transition].get('desired_goal')])
achieved_goals.append([demo_data_obs[epsd][transition].get('achieved_goal')])
for transition in range(self.T):
obs.append([demoData['obs'][epsd ][transition].get('observation')])
acts.append([demoData['acs'][epsd][transition]])
goals.append([demoData['obs'][epsd][transition].get('desired_goal')])
achieved_goals.append([demoData['obs'][epsd][transition].get('achieved_goal')])
for idx, key in enumerate(info_keys):
info_values[idx][transition, i] = demo_data_info[epsd][transition][key]
info_values[idx][transition, i] = demoData['info'][epsd][transition][key]


obs.append([demo_data_obs[epsd][self.T - 1].get('observation')])
achieved_goals.append([demo_data_obs[epsd][self.T - 1].get('achieved_goal')])
obs.append([demoData['obs'][epsd][self.T].get('observation')])
achieved_goals.append([demoData['obs'][epsd][self.T].get('achieved_goal')])

episode = dict(o=obs,
u=acts,
@@ -190,9 +179,10 @@ class DDPG(object):
episode['info_{}'.format(key)] = value

episode = convert_episode_to_batch_major(episode)
global DEMO_BUFFER
DEMO_BUFFER.store_episode(episode) # create the observation dict and append them into the demonstration buffer
logger.debug("Demo buffer size currently ", DEMO_BUFFER.get_current_size()) #print out the demonstration buffer size
global demoBuffer
demoBuffer.store_episode(episode) # create the observation dict and append them into the demonstration buffer

print("Demo buffer size currently ", demoBuffer.get_current_size()) #print out the demonstration buffer size

if update_stats:
# add transitions to normalizer to normalize the demo data as well
@@ -201,7 +191,7 @@ class DDPG(object):
num_normalizing_transitions = transitions_in_episode_batch(episode)
transitions = self.sample_transitions(episode, num_normalizing_transitions)

o, g, ag = transitions['o'], transitions['g'], transitions['ag']
o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
# No need to preprocess the o_2 and g_2 since this is only used for stats

@@ -212,8 +202,6 @@ class DDPG(object):
self.g_stats.recompute_stats()
episode.clear()

logger.info("Demo buffer size: ", DEMO_BUFFER.get_current_size()) #print out the demonstration buffer size

def store_episode(self, episode_batch, update_stats=True):
"""
episode_batch: array of batch_size x (T or T+1) x dim_key
@@ -229,7 +217,7 @@ class DDPG(object):
num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)

o, g, ag = transitions['o'], transitions['g'], transitions['ag']
o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
# No need to preprocess the o_2 and g_2 since this is only used for stats

@@ -263,9 +251,9 @@ class DDPG(object):
def sample_batch(self):
if self.bc_loss: #use demonstration buffer to sample as well if bc_loss flag is set TRUE
transitions = self.buffer.sample(self.batch_size - self.demo_batch_size)
global DEMO_BUFFER
transitions_demo = DEMO_BUFFER.sample(self.demo_batch_size) #sample from the demo buffer
for k, values in transitions_demo.items():
global demoBuffer
transitionsDemo = demoBuffer.sample(self.demo_batch_size) #sample from the demo buffer
for k, values in transitionsDemo.items():
rolloutV = transitions[k].tolist()
for v in values:
rolloutV.append(v.tolist())
@@ -314,7 +302,10 @@ class DDPG(object):

def _create_network(self, reuse=False):
logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))
self.sess = tf_util.get_session()

self.sess = tf.get_default_session()
if self.sess is None:
self.sess = tf.InteractiveSession()

# running averages
with tf.variable_scope('o_stats') as vs:
@@ -376,6 +367,8 @@ class DDPG(object):
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))

self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
assert len(self._vars('main/Q')) == len(Q_grads_tf)
@@ -410,7 +403,7 @@ class DDPG(object):
logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

if prefix != '' and not prefix.endswith('/'):
if prefix is not '' and not prefix.endswith('/'):
return [(prefix + '/' + key, val) for key, val in logs]
else:
return logs
@@ -442,7 +435,3 @@ class DDPG(object):
assert(len(vars) == len(state["tf"]))
node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
self.sess.run(node)

def save(self, save_path):
tf_util.save_variables(save_path)

@@ -1,11 +1,10 @@
import os
import numpy as np
import gym

from baselines import logger
from baselines.her.ddpg import DDPG
from baselines.her.her_sampler import make_sample_her_transitions
from baselines.bench.monitor import Monitor
from baselines.her.her import make_sample_her_transitions


DEFAULT_ENV_PARAMS = {
'FetchReach-v1': {
@@ -73,32 +72,16 @@ def cached_make_env(make_env):
def prepare_params(kwargs):
# DDPG params
ddpg_params = dict()

env_name = kwargs['env_name']

def make_env(subrank=None):
env = gym.make(env_name)
if subrank is not None and logger.get_dir() is not None:
try:
from mpi4py import MPI
mpi_rank = MPI.COMM_WORLD.Get_rank()
except ImportError:
MPI = None
mpi_rank = 0
logger.warn('Running with a single MPI process. This should work, but the results may differ from the ones publshed in Plappert et al.')

max_episode_steps = env._max_episode_steps
env = Monitor(env,
os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
allow_early_resets=True)
# hack to re-expose _max_episode_steps (ideally should replace reliance on it downstream)
env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
return env

def make_env():
return gym.make(env_name)
kwargs['make_env'] = make_env
tmp_env = cached_make_env(kwargs['make_env'])
assert hasattr(tmp_env, '_max_episode_steps')
kwargs['T'] = tmp_env._max_episode_steps

tmp_env.reset()
kwargs['max_u'] = np.array(kwargs['max_u']) if isinstance(kwargs['max_u'], list) else kwargs['max_u']
kwargs['gamma'] = 1. - 1. / kwargs['T']
if 'lr' in kwargs:
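(As an aside on the `gamma` default computed in the hunk above: tying the discount to the horizon `T` makes the effective horizon of the discounted return, `1/(1-gamma)`, equal to the episode length — a quick sanity check, with `T = 50` assumed as a typical Fetch episode length:)

```python
T = 50                 # e.g. a Fetch episode length
gamma = 1. - 1. / T    # 0.98
assert abs(1. / (1. - gamma) - T) < 1e-9  # effective horizon equals T
```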
@@ -1,5 +1,18 @@
import gym
import time
import random
import numpy as np
import rospy
import roslaunch

from random import randint
from std_srvs.srv import Empty
from sensor_msgs.msg import JointState
from geometry_msgs.msg import PoseStamped
from geometry_msgs.msg import Pose
from std_msgs.msg import Float64
from controller_manager_msgs.srv import SwitchController
from gym.utils import seeding


"""Data generation for the case of a single block pick and place in Fetch Env"""
@@ -9,7 +22,7 @@ observations = []
infos = []

def main():
env = gym.make('FetchPickAndPlace-v1')
env = gym.make('FetchPickAndPlace-v0')
numItr = 100
initStateSpace = "random"
env.reset()
@@ -18,19 +31,21 @@ def main():
obs = env.reset()
print("ITERATION NUMBER ", len(actions))
goToGoal(env, obs)



fileName = "data_fetch"
fileName += "_" + initStateSpace
fileName += "_" + str(numItr)
fileName += ".npz"


np.savez_compressed(fileName, acs=actions, obs=observations, info=infos) # save the file

def goToGoal(env, lastObs):

goal = lastObs['desired_goal']
objectPos = lastObs['observation'][3:6]
gripperPos = lastObs['observation'][:3]
gripperState = lastObs['observation'][9:11]
object_rel_pos = lastObs['observation'][6:9]
episodeAcs = []
episodeObs = []
@@ -38,7 +53,7 @@ def goToGoal(env, lastObs):

object_oriented_goal = object_rel_pos.copy()
object_oriented_goal[2] += 0.03 # first make the gripper go slightly above the object


timeStep = 0 #count the total number of timesteps
episodeObs.append(lastObs)

@@ -61,6 +76,8 @@ def goToGoal(env, lastObs):
episodeObs.append(obsDataNew)

objectPos = obsDataNew['observation'][3:6]
gripperPos = obsDataNew['observation'][:3]
gripperState = obsDataNew['observation'][9:11]
object_rel_pos = obsDataNew['observation'][6:9]

while np.linalg.norm(object_rel_pos) >= 0.005 and timeStep <= env._max_episode_steps :
@@ -79,6 +96,8 @@ def goToGoal(env, lastObs):
episodeObs.append(obsDataNew)

objectPos = obsDataNew['observation'][3:6]
gripperPos = obsDataNew['observation'][:3]
gripperState = obsDataNew['observation'][9:11]
object_rel_pos = obsDataNew['observation'][6:9]


@@ -98,6 +117,8 @@ def goToGoal(env, lastObs):
episodeObs.append(obsDataNew)

objectPos = obsDataNew['observation'][3:6]
gripperPos = obsDataNew['observation'][:3]
gripperState = obsDataNew['observation'][9:11]
object_rel_pos = obsDataNew['observation'][6:9]

while True: #limit the number of timesteps in the episode to a fixed duration
@@ -113,6 +134,8 @@ def goToGoal(env, lastObs):
episodeObs.append(obsDataNew)

objectPos = obsDataNew['observation'][3:6]
gripperPos = obsDataNew['observation'][:3]
gripperState = obsDataNew['observation'][9:11]
object_rel_pos = obsDataNew['observation'][6:9]

if timeStep >= env._max_episode_steps: break
@@ -1,4 +1,3 @@
# DEPRECATED, use --play flag to baselines.run instead
import click
import numpy as np
import pickle
@@ -1,5 +1,3 @@
# DEPRECATED, use baselines.common.plot_util instead

import os
import matplotlib.pyplot as plt
import numpy as np

194
baselines/her/experiment/train.py
Normal file
@@ -0,0 +1,194 @@
import os
import sys

import click
import numpy as np
import json
from mpi4py import MPI

from baselines import logger
from baselines.common import set_global_seeds
from baselines.common.mpi_moments import mpi_moments
import baselines.her.experiment.config as config
from baselines.her.rollout import RolloutWorker
from baselines.her.util import mpi_fork

from subprocess import CalledProcessError


def mpi_average(value):
if value == []:
value = [0.]
if not isinstance(value, list):
value = [value]
return mpi_moments(np.array(value))[0]


def train(policy, rollout_worker, evaluator,
n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
save_policies, demo_file, **kwargs):
rank = MPI.COMM_WORLD.Get_rank()

latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl')
best_policy_path = os.path.join(logger.get_dir(), 'policy_best.pkl')
periodic_policy_path = os.path.join(logger.get_dir(), 'policy_{}.pkl')

logger.info("Training...")
best_success_rate = -1

if policy.bc_loss == 1: policy.initDemoBuffer(demo_file) #initialize demo buffer if training with demonstrations
for epoch in range(n_epochs):
# train
rollout_worker.clear_history()
for _ in range(n_cycles):
episode = rollout_worker.generate_rollouts()
policy.store_episode(episode)
for _ in range(n_batches):
policy.train()
policy.update_target_net()

# test
evaluator.clear_history()
for _ in range(n_test_rollouts):
evaluator.generate_rollouts()

# record logs
logger.record_tabular('epoch', epoch)
for key, val in evaluator.logs('test'):
logger.record_tabular(key, mpi_average(val))
for key, val in rollout_worker.logs('train'):
logger.record_tabular(key, mpi_average(val))
for key, val in policy.logs():
logger.record_tabular(key, mpi_average(val))

if rank == 0:
logger.dump_tabular()

# save the policy if it's better than the previous ones
success_rate = mpi_average(evaluator.current_success_rate())
if rank == 0 and success_rate >= best_success_rate and save_policies:
best_success_rate = success_rate
logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
evaluator.save_policy(best_policy_path)
evaluator.save_policy(latest_policy_path)
if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_policies:
policy_path = periodic_policy_path.format(epoch)
logger.info('Saving periodic policy to {} ...'.format(policy_path))
evaluator.save_policy(policy_path)

# make sure that different threads have different seeds
local_uniform = np.random.uniform(size=(1,))
root_uniform = local_uniform.copy()
MPI.COMM_WORLD.Bcast(root_uniform, root=0)
if rank != 0:
assert local_uniform[0] != root_uniform[0]


def launch(
env, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return,
demo_file, override_params={}, save_policies=True
):
# Fork for multi-CPU MPI implementation.
if num_cpu > 1:
try:
whoami = mpi_fork(num_cpu, ['--bind-to', 'core'])
except CalledProcessError:
# fancy version of mpi call failed, try simple version
whoami = mpi_fork(num_cpu)

if whoami == 'parent':
sys.exit(0)
import baselines.common.tf_util as U
U.single_threaded_session().__enter__()
rank = MPI.COMM_WORLD.Get_rank()

# Configure logging
if rank == 0:
if logdir or logger.get_dir() is None:
logger.configure(dir=logdir)
else:
logger.configure()
logdir = logger.get_dir()
assert logdir is not None
os.makedirs(logdir, exist_ok=True)

# Seed everything.
rank_seed = seed + 1000000 * rank
set_global_seeds(rank_seed)

# Prepare params.
params = config.DEFAULT_PARAMS
params['env_name'] = env
params['replay_strategy'] = replay_strategy
if env in config.DEFAULT_ENV_PARAMS:
params.update(config.DEFAULT_ENV_PARAMS[env])  # merge env-specific parameters in
params.update(**override_params)  # makes it possible to override any parameter
with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
json.dump(params, f)
params = config.prepare_params(params)
config.log_params(params, logger=logger)

if num_cpu == 1:
logger.warn()
logger.warn('*** Warning ***')
logger.warn(
'You are running HER with just a single MPI worker. This will work, but the ' +
'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' +
'were obtained with --num_cpu 19. This makes a significant difference and if you ' +
'are looking to reproduce those results, be aware of this. Please also refer to ' +
'https://github.com/openai/baselines/issues/314 for further details.')
logger.warn('****************')
logger.warn()

dims = config.configure_dims(params)
policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)

rollout_params = {
'exploit': False,
'use_target_net': False,
'use_demo_states': True,
'compute_Q': False,
'T': params['T'],
}

eval_params = {
'exploit': True,
'use_target_net': params['test_with_polyak'],
'use_demo_states': False,
'compute_Q': True,
'T': params['T'],
}

for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
rollout_params[name] = params[name]
eval_params[name] = params[name]

rollout_worker = RolloutWorker(params['make_env'], policy, dims, logger, **rollout_params)
rollout_worker.seed(rank_seed)

evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
evaluator.seed(rank_seed)

train(
logdir=logdir, policy=policy, rollout_worker=rollout_worker,
evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
n_cycles=params['n_cycles'], n_batches=params['n_batches'],
policy_save_interval=policy_save_interval, save_policies=save_policies, demo_file=demo_file)


@click.command()
@click.option('--env', type=str, default='FetchReach-v1', help='the name of the OpenAI Gym environment that you want to train on')
@click.option('--logdir', type=str, default=None, help='the path to where logs and policy pickles should go. If not specified, creates a folder in /tmp/')
@click.option('--n_epochs', type=int, default=50, help='the number of training epochs to run')
@click.option('--num_cpu', type=int, default=1, help='the number of CPU cores to use (using MPI)')
@click.option('--seed', type=int, default=0, help='the random seed used to seed both the environment and the training code')
@click.option('--policy_save_interval', type=int, default=5, help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.')
@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.')
@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped')
@click.option('--demo_file', type=str, default = 'PATH/TO/DEMO/DATA/FILE.npz', help='demo data file path')
def main(**kwargs):
launch(**kwargs)


if __name__ == '__main__':
main()
@@ -1,193 +1,63 @@
import os

import click
import numpy as np
import json
from mpi4py import MPI

from baselines import logger
from baselines.common import set_global_seeds, tf_util
from baselines.common.mpi_moments import mpi_moments
import baselines.her.experiment.config as config
from baselines.her.rollout import RolloutWorker

def mpi_average(value):
if not isinstance(value, list):
value = [value]
if not any(value):
value = [0.]
return mpi_moments(np.array(value))[0]


def train(*, policy, rollout_worker, evaluator,
n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
save_path, demo_file, **kwargs):
rank = MPI.COMM_WORLD.Get_rank()
def make_sample_her_transitions(replay_strategy, replay_k, reward_fun):
"""Creates a sample function that can be used for HER experience replay.

if save_path:
latest_policy_path = os.path.join(save_path, 'policy_latest.pkl')
best_policy_path = os.path.join(save_path, 'policy_best.pkl')
periodic_policy_path = os.path.join(save_path, 'policy_{}.pkl')
Args:
replay_strategy (in ['future', 'none']): the HER replay strategy; if set to 'none',
regular DDPG experience replay is used
replay_k (int): the ratio between HER replays and regular replays (e.g. k = 4 -> 4 times
as many HER replays as regular replays are used)
reward_fun (function): function to re-compute the reward with substituted goals
"""
if replay_strategy == 'future':
future_p = 1 - (1. / (1 + replay_k))
else:  # 'replay_strategy' == 'none'
future_p = 0

logger.info("Training...")
best_success_rate = -1
def _sample_her_transitions(episode_batch, batch_size_in_transitions):
"""episode_batch is {key: array(buffer_size x T x dim_key)}
"""
T = episode_batch['u'].shape[1]
rollout_batch_size = episode_batch['u'].shape[0]
batch_size = batch_size_in_transitions

if policy.bc_loss == 1: policy.init_demo_buffer(demo_file) #initialize demo buffer if training with demonstrations
# Select which episodes and time steps to use.
episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
t_samples = np.random.randint(T, size=batch_size)
transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
for key in episode_batch.keys()}

# num_timesteps = n_epochs * n_cycles * rollout_length * number of rollout workers
for epoch in range(n_epochs):
# train
rollout_worker.clear_history()
for _ in range(n_cycles):
episode = rollout_worker.generate_rollouts()
policy.store_episode(episode)
for _ in range(n_batches):
policy.train()
policy.update_target_net()
# Select future time indexes proportional with probability future_p. These
# will be used for HER replay by substituting in future goals.
her_indexes = np.where(np.random.uniform(size=batch_size) < future_p)
future_offset = np.random.uniform(size=batch_size) * (T - t_samples)
future_offset = future_offset.astype(int)
future_t = (t_samples + 1 + future_offset)[her_indexes]

# test
evaluator.clear_history()
for _ in range(n_test_rollouts):
evaluator.generate_rollouts()
# Replace goal with achieved goal but only for the previously-selected
# HER transitions (as defined by her_indexes). For the other transitions,
# keep the original goal.
future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
transitions['g'][her_indexes] = future_ag

# record logs
logger.record_tabular('epoch', epoch)
for key, val in evaluator.logs('test'):
logger.record_tabular(key, mpi_average(val))
for key, val in rollout_worker.logs('train'):
logger.record_tabular(key, mpi_average(val))
for key, val in policy.logs():
logger.record_tabular(key, mpi_average(val))
# Reconstruct info dictionary for reward computation.
info = {}
for key, value in transitions.items():
if key.startswith('info_'):
info[key.replace('info_', '')] = value

if rank == 0:
logger.dump_tabular()
# Re-compute reward since we may have substituted the goal.
reward_params = {k: transitions[k] for k in ['ag_2', 'g']}
reward_params['info'] = info
transitions['r'] = reward_fun(**reward_params)

# save the policy if it's better than the previous ones
success_rate = mpi_average(evaluator.current_success_rate())
if rank == 0 and success_rate >= best_success_rate and save_path:
best_success_rate = success_rate
logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
evaluator.save_policy(best_policy_path)
evaluator.save_policy(latest_policy_path)
if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_path:
policy_path = periodic_policy_path.format(epoch)
logger.info('Saving periodic policy to {} ...'.format(policy_path))
evaluator.save_policy(policy_path)
transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:])
for k in transitions.keys()}

# make sure that different threads have different seeds
local_uniform = np.random.uniform(size=(1,))
root_uniform = local_uniform.copy()
MPI.COMM_WORLD.Bcast(root_uniform, root=0)
if rank != 0:
assert local_uniform[0] != root_uniform[0]
assert(transitions['u'].shape[0] == batch_size_in_transitions)

return policy
return transitions


def learn(*, network, env, total_timesteps,
seed=None,
eval_env=None,
replay_strategy='future',
policy_save_interval=5,
clip_return=True,
demo_file=None,
override_params=None,
load_path=None,
save_path=None,
**kwargs
):

override_params = override_params or {}
if MPI is not None:
rank = MPI.COMM_WORLD.Get_rank()
num_cpu = MPI.COMM_WORLD.Get_size()

# Seed everything.
rank_seed = seed + 1000000 * rank if seed is not None else None
set_global_seeds(rank_seed)

# Prepare params.
params = config.DEFAULT_PARAMS
env_name = env.specs[0].id
params['env_name'] = env_name
params['replay_strategy'] = replay_strategy
if env_name in config.DEFAULT_ENV_PARAMS:
params.update(config.DEFAULT_ENV_PARAMS[env_name])  # merge env-specific parameters in
params.update(**override_params)  # makes it possible to override any parameter
with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
json.dump(params, f)
params = config.prepare_params(params)
params['rollout_batch_size'] = env.num_envs

if demo_file is not None:
params['bc_loss'] = 1
params.update(kwargs)

config.log_params(params, logger=logger)

if num_cpu == 1:
logger.warn()
logger.warn('*** Warning ***')
logger.warn(
'You are running HER with just a single MPI worker. This will work, but the ' +
'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' +
'were obtained with --num_cpu 19. This makes a significant difference and if you ' +
'are looking to reproduce those results, be aware of this. Please also refer to ' +
'https://github.com/openai/baselines/issues/314 for further details.')
logger.warn('****************')
logger.warn()

dims = config.configure_dims(params)
policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
if load_path is not None:
tf_util.load_variables(load_path)

rollout_params = {
'exploit': False,
'use_target_net': False,
'use_demo_states': True,
'compute_Q': False,
'T': params['T'],
}

eval_params = {
'exploit': True,
'use_target_net': params['test_with_polyak'],
'use_demo_states': False,
'compute_Q': True,
'T': params['T'],
}

for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
rollout_params[name] = params[name]
eval_params[name] = params[name]

eval_env = eval_env or env

rollout_worker = RolloutWorker(env, policy, dims, logger, monitor=True, **rollout_params)
evaluator = RolloutWorker(eval_env, policy, dims, logger, **eval_params)

n_cycles = params['n_cycles']
n_epochs = total_timesteps // n_cycles // rollout_worker.T // rollout_worker.rollout_batch_size

return train(
save_path=save_path, policy=policy, rollout_worker=rollout_worker,
evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
n_cycles=params['n_cycles'], n_batches=params['n_batches'],
policy_save_interval=policy_save_interval, demo_file=demo_file)


@click.command()
@click.option('--env', type=str, default='FetchReach-v1', help='the name of the OpenAI Gym environment that you want to train on')
@click.option('--total_timesteps', type=int, default=int(5e5), help='the number of timesteps to run')
@click.option('--seed', type=int, default=0, help='the random seed used to seed both the environment and the training code')
@click.option('--policy_save_interval', type=int, default=5, help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.')
@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.')
@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped')
@click.option('--demo_file', type=str, default = 'PATH/TO/DEMO/DATA/FILE.npz', help='demo data file path')
def main(**kwargs):
learn(**kwargs)


if __name__ == '__main__':
main()
return _sample_her_transitions
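As a quick sanity check on the `future_p` formula above: with the default `replay_k = 4`, a sampled transition keeps its original goal with probability `1 / (1 + replay_k)` and is relabeled with a future achieved goal otherwise — a minimal sketch:

```python
replay_k = 4                          # default ratio of HER to regular replays
future_p = 1 - (1. / (1 + replay_k))  # 0.8
print(future_p)  # 4 out of 5 sampled transitions get a substituted goal
```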
@@ -1,63 +0,0 @@
import numpy as np


def make_sample_her_transitions(replay_strategy, replay_k, reward_fun):
"""Creates a sample function that can be used for HER experience replay.

Args:
replay_strategy (in ['future', 'none']): the HER replay strategy; if set to 'none',
regular DDPG experience replay is used
replay_k (int): the ratio between HER replays and regular replays (e.g. k = 4 -> 4 times
as many HER replays as regular replays are used)
reward_fun (function): function to re-compute the reward with substituted goals
"""
if replay_strategy == 'future':
future_p = 1 - (1. / (1 + replay_k))
else:  # 'replay_strategy' == 'none'
future_p = 0

def _sample_her_transitions(episode_batch, batch_size_in_transitions):
"""episode_batch is {key: array(buffer_size x T x dim_key)}
"""
T = episode_batch['u'].shape[1]
rollout_batch_size = episode_batch['u'].shape[0]
batch_size = batch_size_in_transitions

# Select which episodes and time steps to use.
episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
t_samples = np.random.randint(T, size=batch_size)
transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
for key in episode_batch.keys()}

# Select future time indexes proportional with probability future_p. These
# will be used for HER replay by substituting in future goals.
her_indexes = np.where(np.random.uniform(size=batch_size) < future_p)
future_offset = np.random.uniform(size=batch_size) * (T - t_samples)
future_offset = future_offset.astype(int)
future_t = (t_samples + 1 + future_offset)[her_indexes]

# Replace goal with achieved goal but only for the previously-selected
# HER transitions (as defined by her_indexes). For the other transitions,
# keep the original goal.
future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
transitions['g'][her_indexes] = future_ag

# Reconstruct info dictionary for reward computation.
info = {}
for key, value in transitions.items():
if key.startswith('info_'):
info[key.replace('info_', '')] = value

# Re-compute reward since we may have substituted the goal.
reward_params = {k: transitions[k] for k in ['ag_2', 'g']}
reward_params['info'] = info
transitions['r'] = reward_fun(**reward_params)

transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:])
for k in transitions.keys()}

assert(transitions['u'].shape[0] == batch_size_in_transitions)

return transitions

return _sample_her_transitions
@@ -2,6 +2,7 @@ from collections import deque

import numpy as np
import pickle
from mujoco_py import MujocoException

from baselines.her.util import convert_episode_to_batch_major, store_args

@@ -9,9 +10,9 @@ from baselines.her.util import convert_episode_to_batch_major, store_args
class RolloutWorker:

@store_args
def __init__(self, venv, policy, dims, logger, T, rollout_batch_size=1,
def __init__(self, make_env, policy, dims, logger, T, rollout_batch_size=1,
exploit=False, use_target_net=False, compute_Q=False, noise_eps=0,
random_eps=0, history_len=100, render=False, monitor=False, **kwargs):
random_eps=0, history_len=100, render=False, **kwargs):
"""Rollout worker generates experience by interacting with one or many environments.

Args:
@@ -30,7 +31,7 @@ class RolloutWorker:
history_len (int): length of history for statistics smoothing
render (boolean): whether or not to render the rollouts
"""

self.envs = [make_env() for _ in range(rollout_batch_size)]
assert self.T > 0

self.info_keys = [key.replace('info_', '') for key in dims.keys() if key.startswith('info_')]
@@ -39,14 +40,26 @@ class RolloutWorker:
self.Q_history = deque(maxlen=history_len)

self.n_episodes = 0
self.g = np.empty((self.rollout_batch_size, self.dims['g']), np.float32)  # goals
self.initial_o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32)  # observations
self.initial_ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32)  # achieved goals
self.reset_all_rollouts()
self.clear_history()

def reset_rollout(self, i):
"""Resets the `i`-th rollout environment, re-samples a new goal, and updates the `initial_o`
and `g` arrays accordingly.
"""
obs = self.envs[i].reset()
self.initial_o[i] = obs['observation']
self.initial_ag[i] = obs['achieved_goal']
self.g[i] = obs['desired_goal']

def reset_all_rollouts(self):
self.obs_dict = self.venv.reset()
self.initial_o = self.obs_dict['observation']
self.initial_ag = self.obs_dict['achieved_goal']
self.g = self.obs_dict['desired_goal']
"""Resets all `rollout_batch_size` rollout workers.
"""
for i in range(self.rollout_batch_size):
self.reset_rollout(i)

def generate_rollouts(self):
"""Performs `rollout_batch_size` rollouts in parallel for time horizon `T` with the current
@@ -62,8 +75,7 @@ class RolloutWorker:

# generate episodes
obs, achieved_goals, acts, goals, successes = [], [], [], [], []
dones = []
info_values = [np.empty((self.T - 1, self.rollout_batch_size, self.dims['info_' + key]), np.float32) for key in self.info_keys]
info_values = [np.empty((self.T, self.rollout_batch_size, self.dims['info_' + key]), np.float32) for key in self.info_keys]
Qs = []
for t in range(self.T):
policy_output = self.policy.get_actions(
@@ -87,27 +99,27 @@ class RolloutWorker:
ag_new = np.empty((self.rollout_batch_size, self.dims['g']))
success = np.zeros(self.rollout_batch_size)
# compute new states and observations
obs_dict_new, _, done, info = self.venv.step(u)
o_new = obs_dict_new['observation']
ag_new = obs_dict_new['achieved_goal']
success = np.array([i.get('is_success', 0.0) for i in info])

if any(done):
# here we assume all environments are done in ~same number of steps, so we terminate rollouts whenever any of the envs returns done
# trick with using vecenvs is not to add the obs from the environments that are "done", because those are already observations
# after a reset
break

for i, info_dict in enumerate(info):
for idx, key in enumerate(self.info_keys):
info_values[idx][t, i] = info[i][key]
for i in range(self.rollout_batch_size):
try:
# We fully ignore the reward here because it will have to be re-computed
# for HER.
curr_o_new, _, _, info = self.envs[i].step(u[i])
if 'is_success' in info:
success[i] = info['is_success']
o_new[i] = curr_o_new['observation']
ag_new[i] = curr_o_new['achieved_goal']
for idx, key in enumerate(self.info_keys):
info_values[idx][t, i] = info[key]
if self.render:
self.envs[i].render()
except MujocoException as e:
return self.generate_rollouts()

if np.isnan(o_new).any():
self.logger.warn('NaN caught during rollout generation. Trying again...')
self.reset_all_rollouts()
return self.generate_rollouts()

dones.append(done)
obs.append(o.copy())
achieved_goals.append(ag.copy())
successes.append(success.copy())
@@ -117,6 +129,7 @@ class RolloutWorker:
ag[...] = ag_new
obs.append(o.copy())
achieved_goals.append(ag.copy())
self.initial_o[:] = o

episode = dict(o=obs,
u=acts,
@@ -163,8 +176,13 @@ class RolloutWorker:
logs += [('mean_Q', np.mean(self.Q_history))]
logs += [('episode', self.n_episodes)]

if prefix != '' and not prefix.endswith('/'):
if prefix is not '' and not prefix.endswith('/'):
return [(prefix + '/' + key, val) for key, val in logs]
else:
return logs

def seed(self, seed):
"""Seeds each environment with a distinct seed derived from the passed in global seed.
"""
for idx, env in enumerate(self.envs):
env.seed(seed + 1000 * idx)

@@ -54,7 +54,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
# Write out the data
dashes = '-' * (keywidth + valwidth + 7)
lines = [dashes]
for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
for (key, val) in sorted(key2str.items()):
lines.append('| %s%s | %s%s |' % (
key,
' ' * (keywidth - len(key)),

@@ -97,7 +97,7 @@ def learn(env, policy_fn, *,
ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return

lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
clip_param = clip_param * lrmult # Annealed clipping parameter epsilon
clip_param = clip_param * lrmult # Annealed cliping parameter epislon

ob = U.get_placeholder_cached(name="ob")
ac = pi.pdtype.sample_placeholder([None])
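The annealed clipping in the hunk above ties the PPO clip range to the learning-rate multiplier `lrmult`; as a rough sketch of how such a linear schedule behaves over training (hypothetical numbers, not the exact schedule used by the code):

```python
clip_param = 0.2                 # initial clipping epsilon
max_timesteps = 1_000_000
for t in (0, 500_000, 1_000_000):
    lrmult = max(1.0 - float(t) / max_timesteps, 0)  # linear decay
    print(t, clip_param * lrmult)  # epsilon anneals 0.2 -> 0.1 -> 0.0
```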
@@ -20,6 +20,3 @@ def atari():
lr=lambda f : f * 2.5e-4,
cliprange=lambda f : f * 0.1,
)

def retro():
return atari()

@@ -1,76 +0,0 @@
import tensorflow as tf
import numpy as np
from baselines.ppo2.model import Model

class MicrobatchedModel(Model):
"""
Model that does training one microbatch at a time - when gradient computation
on the entire minibatch causes some overflow
"""
def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size):

self.nmicrobatches = nbatch_train // microbatch_size
self.microbatch_size = microbatch_size
assert nbatch_train % microbatch_size == 0, 'microbatch_size ({}) should divide nbatch_train ({}) evenly'.format(microbatch_size, nbatch_train)

super().__init__(
policy=policy,
ob_space=ob_space,
ac_space=ac_space,
nbatch_act=nbatch_act,
nbatch_train=microbatch_size,
nsteps=nsteps,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm)

self.grads_ph = [tf.placeholder(dtype=g.dtype, shape=g.shape) for g in self.grads]
grads_ph_and_vars = list(zip(self.grads_ph, self.var))
self._apply_gradients_op = self.trainer.apply_gradients(grads_ph_and_vars)


def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
assert states is None, "microbatches with recurrent models are not supported yet"

# Here we calculate advantage A(s,a) = R + yV(s') - V(s)
# Returns = R + yV(s')
advs = returns - values

# Normalize the advantages
advs = (advs - advs.mean()) / (advs.std() + 1e-8)

# Initialize empty list for per-microbatch stats like pg_loss, vf_loss, entropy, approxkl (whatever is in self.stats_list)
stats_vs = []

for microbatch_idx in range(self.nmicrobatches):
_sli = range(microbatch_idx * self.microbatch_size, (microbatch_idx+1) * self.microbatch_size)
td_map = {
self.train_model.X: obs[_sli],
self.A:actions[_sli],
self.ADV:advs[_sli],
self.R:returns[_sli],
self.CLIPRANGE:cliprange,
self.OLDNEGLOGPAC:neglogpacs[_sli],
self.OLDVPRED:values[_sli]
}

# Compute gradient on a microbatch (note that variables do not change here) ...
grad_v, stats_v = self.sess.run([self.grads, self.stats_list], td_map)
if microbatch_idx == 0:
sum_grad_v = grad_v
else:
# .. and add to the total of the gradients
for i, g in enumerate(grad_v):
sum_grad_v[i] += g
stats_vs.append(stats_v)

feed_dict = {ph: sum_g / self.nmicrobatches for ph, sum_g in zip(self.grads_ph, sum_grad_v)}
feed_dict[self.LR] = lr
# Update variables using average of the gradients
self.sess.run(self._apply_gradients_op, feed_dict)
# Return average of the stats
return np.mean(np.array(stats_vs), axis=0).tolist()
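As a sanity check on the gradient-averaging logic above: for a loss that is a mean over samples, averaging equal-size microbatch gradients reproduces the full-minibatch gradient — a tiny numpy sketch with hypothetical per-sample gradients:

```python
import numpy as np

per_sample_grads = np.random.randn(8)          # pretend per-sample gradients
microbatches = per_sample_grads.reshape(2, 4)  # two microbatches of size 4
full = per_sample_grads.mean()                 # gradient of the mean loss
accumulated = np.mean([m.mean() for m in microbatches])
assert np.isclose(full, accumulated)
```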
@@ -1,156 +0,0 @@
import tensorflow as tf
import functools

from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.common.tf_util import initialize

try:
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
from mpi4py import MPI
from baselines.common.mpi_util import sync_from_root
except ImportError:
MPI = None

class Model(object):
"""
We use this object to :
__init__:
- Creates the step_model
- Creates the train_model

train():
- Make the training part (feedforward and retropropagation of gradients)

save/load():
- Save load the model
"""
def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size=None):
self.sess = sess = get_session()

with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE):
# CREATE OUR TWO MODELS
# act_model that is used for sampling
act_model = policy(nbatch_act, 1, sess)

# Train model for training
if microbatch_size is None:
train_model = policy(nbatch_train, nsteps, sess)
else:
train_model = policy(microbatch_size, nsteps, sess)

# CREATE THE PLACEHOLDERS
self.A = A = train_model.pdtype.sample_placeholder([None])
self.ADV = ADV = tf.placeholder(tf.float32, [None])
self.R = R = tf.placeholder(tf.float32, [None])
# Keep track of old actor
self.OLDNEGLOGPAC = OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
# Keep track of old critic
self.OLDVPRED = OLDVPRED = tf.placeholder(tf.float32, [None])
self.LR = LR = tf.placeholder(tf.float32, [])
# Cliprange
self.CLIPRANGE = CLIPRANGE = tf.placeholder(tf.float32, [])

neglogpac = train_model.pd.neglogp(A)

# Calculate the entropy
# Entropy is used to improve exploration by limiting the premature convergence to suboptimal policy.
entropy = tf.reduce_mean(train_model.pd.entropy())

# CALCULATE THE LOSS
# Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss

# Clip the value to reduce variability during Critic training
# Get the predicted value
vpred = train_model.vf
vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE)
# Unclipped value
vf_losses1 = tf.square(vpred - R)
# Clipped value
vf_losses2 = tf.square(vpredclipped - R)

vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))

# Calculate ratio (pi current policy / pi old policy)
ratio = tf.exp(OLDNEGLOGPAC - neglogpac)

# Defining Loss = - J is equivalent to max J
pg_losses = -ADV * ratio

pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)

# Final PG loss
pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC))
clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))

# Total loss
loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef

# UPDATE THE PARAMETERS USING LOSS
# 1. Get the model parameters
params = tf.trainable_variables('ppo2_model')
# 2. Build our trainer
if MPI is not None:
self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
else:
self.trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
# 3. Calculate the gradients
grads_and_var = self.trainer.compute_gradients(loss, params)
grads, var = zip(*grads_and_var)

if max_grad_norm is not None:
# Clip the gradients (normalize)
grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
grads_and_var = list(zip(grads, var))
# zip aggregate each gradient with parameters associated
# For instance zip(ABCD, xyza) => Ax, By, Cz, Da

self.grads = grads
self.var = var
self._train_op = self.trainer.apply_gradients(grads_and_var)
self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac']
self.stats_list = [pg_loss, vf_loss, entropy, approxkl, clipfrac]


self.train_model = train_model
self.act_model = act_model
self.step = act_model.step
self.value = act_model.value
self.initial_state = act_model.initial_state

self.save = functools.partial(save_variables, sess=sess)
self.load = functools.partial(load_variables, sess=sess)

initialize()
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
if MPI is not None:
sync_from_root(sess, global_variables) #pylint: disable=E1101

def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
# Here we calculate advantage A(s,a) = R + yV(s') - V(s)
# Returns = R + yV(s')
advs = returns - values

# Normalize the advantages
advs = (advs - advs.mean()) / (advs.std() + 1e-8)

td_map = {
self.train_model.X : obs,
self.A : actions,
self.ADV : advs,
self.R : returns,
self.LR : lr,
self.CLIPRANGE : cliprange,
self.OLDNEGLOGPAC : neglogpacs,
self.OLDVPRED : values
}
if states is not None:
td_map[self.train_model.S] = states
td_map[self.train_model.M] = masks

return self.sess.run(
self.stats_list + [self._train_op],
td_map
)[:-1]

@@ -1,17 +1,226 @@
|
||||
import os
|
||||
import time
|
||||
import functools
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
import tensorflow as tf
|
||||
from baselines import logger
|
||||
from collections import deque
|
||||
from baselines.common import explained_variance, set_global_seeds
|
||||
from baselines.common.policies import build_policy
|
||||
from baselines.common.runners import AbstractEnvRunner
|
||||
from baselines.common.tf_util import get_session, save_variables, load_variables
|
||||
|
||||
try:
|
||||
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
|
||||
from mpi4py import MPI
|
||||
from baselines.common.mpi_util import sync_from_root
|
||||
except ImportError:
|
||||
MPI = None
|
||||
from baselines.ppo2.runner import Runner
|
||||
|
||||
from baselines.common.tf_util import initialize
|
||||
|
||||
class Model(object):
|
||||
"""
|
||||
We use this object to :
|
||||
__init__:
|
||||
- Creates the step_model
|
||||
- Creates the train_model
|
||||
|
||||
train():
|
||||
- Make the training part (feedforward and retropropagation of gradients)
|
||||
|
||||
save/load():
|
||||
- Save load the model
|
||||
"""
|
||||
def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
|
||||
nsteps, ent_coef, vf_coef, max_grad_norm):
|
||||
sess = get_session()
|
||||
|
||||
with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE):
|
||||
# CREATE OUR TWO MODELS
|
||||
# act_model that is used for sampling
|
||||
act_model = policy(nbatch_act, 1, sess)
|
||||
|
||||
# Train model for training
|
||||
train_model = policy(nbatch_train, nsteps, sess)
|
||||
|
||||
# CREATE THE PLACEHOLDERS
|
||||
A = train_model.pdtype.sample_placeholder([None])
|
||||
ADV = tf.placeholder(tf.float32, [None])
|
||||
R = tf.placeholder(tf.float32, [None])
|
||||
# Keep track of old actor
|
||||
OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
|
||||
# Keep track of old critic
|
||||
OLDVPRED = tf.placeholder(tf.float32, [None])
|
||||
LR = tf.placeholder(tf.float32, [])
|
||||
# Cliprange
|
||||
CLIPRANGE = tf.placeholder(tf.float32, [])
|
||||
|
||||
        neglogpac = train_model.pd.neglogp(A)

        # Calculate the entropy
        # Entropy is used to improve exploration by limiting premature convergence to a suboptimal policy.
        entropy = tf.reduce_mean(train_model.pd.entropy())

        # CALCULATE THE LOSS
        # Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss

        # Clip the value estimate to reduce variability during critic training
        # Get the predicted value
        vpred = train_model.vf
        vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE)
        # Unclipped value
        vf_losses1 = tf.square(vpred - R)
        # Clipped value
        vf_losses2 = tf.square(vpredclipped - R)

        vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))

        # Calculate the ratio (pi current policy / pi old policy)
        ratio = tf.exp(OLDNEGLOGPAC - neglogpac)

        # Defining Loss = - J is equivalent to max J
        pg_losses = -ADV * ratio

        pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)

        # Final PG loss
        pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
        approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC))
        clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))

        # Total loss
        loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef

        # UPDATE THE PARAMETERS USING LOSS
        # 1. Get the model parameters
        params = tf.trainable_variables('ppo2_model')
        # 2. Build our trainer
        if MPI is not None:
            trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
        # 3. Calculate the gradients
        grads_and_var = trainer.compute_gradients(loss, params)
        grads, var = zip(*grads_and_var)

        if max_grad_norm is not None:
            # Clip the gradients (normalize)
            grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
        grads_and_var = list(zip(grads, var))
        # zip pairs each gradient with its associated parameter
        # e.g. zip(ABCD, xyza) => Ax, By, Cz, Da

        _train = trainer.apply_gradients(grads_and_var)

        def train(lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
            # Here we calculate the advantage A(s,a) = R + gamma * V(s') - V(s)
            # Returns = R + gamma * V(s')
            advs = returns - values

            # Normalize the advantages
            advs = (advs - advs.mean()) / (advs.std() + 1e-8)
            td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
                      CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
            if states is not None:
                td_map[train_model.S] = states
                td_map[train_model.M] = masks
            return sess.run(
                [pg_loss, vf_loss, entropy, approxkl, clipfrac, _train],
                td_map
            )[:-1]
        self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac']

        self.train = train
        self.train_model = train_model
        self.act_model = act_model
        self.step = act_model.step
        self.value = act_model.value
        self.initial_state = act_model.initial_state

        self.save = functools.partial(save_variables, sess=sess)
        self.load = functools.partial(load_variables, sess=sess)

        if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
            initialize()
        global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")

        if MPI is not None:
            sync_from_root(sess, global_variables)  # pylint: disable=E1101

class Runner(AbstractEnvRunner):
    """
    We use this object to make a mini batch of experiences
    __init__:
    - Initialize the runner

    run():
    - Make a mini batch
    """
    def __init__(self, *, env, model, nsteps, gamma, lam):
        super().__init__(env=env, model=model, nsteps=nsteps)
        # Lambda used in GAE (Generalized Advantage Estimation)
        self.lam = lam
        # Discount rate
        self.gamma = gamma

    def run(self):
        # Here, we init the lists that will contain the mb of experiences
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]
        mb_states = self.states
        epinfos = []
        # For each step in range(nsteps)
        for _ in range(self.nsteps):
            # Given observations, get action value and neglogpacs
            # We already have self.obs because the Runner superclass runs self.obs[:] = env.reset() on init
            actions, values, self.states, neglogpacs = self.model.step(self.obs, S=self.states, M=self.dones)
            mb_obs.append(self.obs.copy())
            mb_actions.append(actions)
            mb_values.append(values)
            mb_neglogpacs.append(neglogpacs)
            mb_dones.append(self.dones)

            # Take actions in the env and look at the results
            # infos contains a ton of useful information
            self.obs[:], rewards, self.dones, infos = self.env.step(actions)
            for info in infos:
                maybeepinfo = info.get('episode')
                if maybeepinfo: epinfos.append(maybeepinfo)
            mb_rewards.append(rewards)
        # batch of steps to batch of rollouts
        mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
        mb_actions = np.asarray(mb_actions)
        mb_values = np.asarray(mb_values, dtype=np.float32)
        mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
        mb_dones = np.asarray(mb_dones, dtype=np.bool)
        last_values = self.model.value(self.obs, S=self.states, M=self.dones)

        # discount/bootstrap off value fn
        mb_returns = np.zeros_like(mb_rewards)
        mb_advs = np.zeros_like(mb_rewards)
        lastgaelam = 0
        for t in reversed(range(self.nsteps)):
            if t == self.nsteps - 1:
                nextnonterminal = 1.0 - self.dones
                nextvalues = last_values
            else:
                nextnonterminal = 1.0 - mb_dones[t+1]
                nextvalues = mb_values[t+1]
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_values[t]
            mb_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam
        mb_returns = mb_advs + mb_values
        return (*map(sf01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)),
                mb_states, epinfos)
# obs, returns, masks, actions, values, neglogpacs, states = runner.run()

def sf01(arr):
    """
    swap and then flatten axes 0 and 1
    """
    s = arr.shape
    return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])

def constfn(val):
    def f(_):
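To make the loss construction in the listing above concrete, here is a minimal NumPy sketch of the clipped surrogate objective with made-up numbers; the variable names mirror the TensorFlow graph but are otherwise illustrative:

```python
import numpy as np

cliprange = 0.2
adv = np.array([1.0, -1.0, 2.0])              # normalized advantages
old_neglogpac = np.array([1.2, 0.7, 2.0])     # -log pi_old(a|s), from the rollout
neglogpac = np.array([1.0, 0.9, 1.0])         # -log pi(a|s), current policy

# exp(old_neglogpac - neglogpac) == pi(a|s) / pi_old(a|s), computed in log space
ratio = np.exp(old_neglogpac - neglogpac)

# Pessimistic maximum of the unclipped and clipped policy-gradient losses
pg_loss = np.mean(np.maximum(-adv * ratio,
                             -adv * np.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)))

# The training diagnostics computed alongside the loss
approxkl = 0.5 * np.mean(np.square(neglogpac - old_neglogpac))
clipfrac = np.mean(np.abs(ratio - 1.0) > cliprange)
```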
@@ -21,7 +230,7 @@ def constfn(val):
def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
            vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
            log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
            save_interval=0, load_path=None, model_fn=None, **network_kwargs):
            save_interval=0, load_path=None, **network_kwargs):
    '''
    Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)
@@ -99,14 +308,10 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
    nbatch_train = nbatch // nminibatches

    # Instantiate the model object (that creates act_model and train_model)
    if model_fn is None:
        from baselines.ppo2.model import Model
        model_fn = Model

    model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
    make_model = lambda : Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
                                nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
                                max_grad_norm=max_grad_norm)

    model = make_model()
    if load_path is not None:
        model.load(load_path)
    # Instantiate the runner object
@@ -114,6 +319,8 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
    if eval_env is not None:
        eval_runner = Runner(env = eval_env, model = model, nsteps = nsteps, gamma = gamma, lam= lam)

    epinfobuf = deque(maxlen=100)
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=100)
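The `deque(maxlen=100)` buffers above implement the rolling episode statistics the logger reports (e.g. `eprewmean`); a minimal sketch of the idea, with hypothetical Monitor-style episode infos:

```python
from collections import deque
import numpy as np

epinfobuf = deque(maxlen=100)   # only the 100 most recent episodes are kept
# Monitor-style episode infos: 'r' is episode reward, 'l' is episode length
epinfobuf.extend([{'r': 1.0, 'l': 10}, {'r': 3.0, 'l': 20}])
eprewmean = np.mean([epinfo['r'] for epinfo in epinfobuf])   # -> 2.0
```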
@@ -1,76 +0,0 @@
import numpy as np
from baselines.common.runners import AbstractEnvRunner

class Runner(AbstractEnvRunner):
    """
    We use this object to make a mini batch of experiences
    __init__:
    - Initialize the runner

    run():
    - Make a mini batch
    """
    def __init__(self, *, env, model, nsteps, gamma, lam):
        super().__init__(env=env, model=model, nsteps=nsteps)
        # Lambda used in GAE (Generalized Advantage Estimation)
        self.lam = lam
        # Discount rate
        self.gamma = gamma

    def run(self):
        # Here, we init the lists that will contain the mb of experiences
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]
        mb_states = self.states
        epinfos = []
        # For each step in range(nsteps)
        for _ in range(self.nsteps):
            # Given observations, get action value and neglogpacs
            # We already have self.obs because the Runner superclass runs self.obs[:] = env.reset() on init
            actions, values, self.states, neglogpacs = self.model.step(self.obs, S=self.states, M=self.dones)
            mb_obs.append(self.obs.copy())
            mb_actions.append(actions)
            mb_values.append(values)
            mb_neglogpacs.append(neglogpacs)
            mb_dones.append(self.dones)

            # Take actions in the env and look at the results
            # infos contains a ton of useful information
            self.obs[:], rewards, self.dones, infos = self.env.step(actions)
            for info in infos:
                maybeepinfo = info.get('episode')
                if maybeepinfo: epinfos.append(maybeepinfo)
            mb_rewards.append(rewards)
        # batch of steps to batch of rollouts
        mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
        mb_actions = np.asarray(mb_actions)
        mb_values = np.asarray(mb_values, dtype=np.float32)
        mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
        mb_dones = np.asarray(mb_dones, dtype=np.bool)
        last_values = self.model.value(self.obs, S=self.states, M=self.dones)

        # discount/bootstrap off value fn
        mb_returns = np.zeros_like(mb_rewards)
        mb_advs = np.zeros_like(mb_rewards)
        lastgaelam = 0
        for t in reversed(range(self.nsteps)):
            if t == self.nsteps - 1:
                nextnonterminal = 1.0 - self.dones
                nextvalues = last_values
            else:
                nextnonterminal = 1.0 - mb_dones[t+1]
                nextvalues = mb_values[t+1]
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_values[t]
            mb_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam
        mb_returns = mb_advs + mb_values
        return (*map(sf01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)),
                mb_states, epinfos)
# obs, returns, masks, actions, values, neglogpacs, states = runner.run()

def sf01(arr):
    """
    swap and then flatten axes 0 and 1
    """
    s = arr.shape
    return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])
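Since the GAE recursion in `run()` is the core of this file, here is a minimal NumPy sketch of it with toy numbers (one environment, four steps); the done flag recorded at step 2 stops bootstrapping across the episode boundary exactly as `nextnonterminal` does above, and the final line shows what `sf01` does to the batch layout:

```python
import numpy as np

gamma, lam, nsteps = 0.99, 0.95, 4
mb_rewards = np.array([[1.0], [0.0], [1.0], [0.0]])   # shape (nsteps, nenvs)
mb_values  = np.array([[0.5], [0.4], [0.6], [0.3]])
mb_dones   = np.array([[0.0], [0.0], [1.0], [0.0]])   # episode boundary at step 2
last_values, dones = np.array([0.2]), np.array([0.0])

mb_advs = np.zeros_like(mb_rewards)
lastgaelam = 0
for t in reversed(range(nsteps)):
    if t == nsteps - 1:
        nextnonterminal, nextvalues = 1.0 - dones, last_values
    else:
        nextnonterminal, nextvalues = 1.0 - mb_dones[t + 1], mb_values[t + 1]
    delta = mb_rewards[t] + gamma * nextvalues * nextnonterminal - mb_values[t]
    mb_advs[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
mb_returns = mb_advs + mb_values

# sf01 turns (nsteps, nenvs, ...) into (nsteps * nenvs, ...):
flat = mb_returns.swapaxes(0, 1).reshape(-1)   # shape (4,)
```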
@@ -1,34 +0,0 @@
import gym
import tensorflow as tf
import numpy as np
from functools import partial

from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.tf_util import make_session
from baselines.ppo2.ppo2 import learn

from baselines.ppo2.microbatched_model import MicrobatchedModel

def test_microbatches():
    def env_fn():
        env = gym.make('CartPole-v0')
        env.seed(0)
        return env

    learn_fn = partial(learn, network='mlp', nsteps=32, total_timesteps=32, seed=0)

    env_ref = DummyVecEnv([env_fn])
    sess_ref = make_session(make_default=True, graph=tf.Graph())
    learn_fn(env=env_ref)
    vars_ref = {v.name: sess_ref.run(v) for v in tf.trainable_variables()}

    env_test = DummyVecEnv([env_fn])
    sess_test = make_session(make_default=True, graph=tf.Graph())
    learn_fn(env=env_test, model_fn=partial(MicrobatchedModel, microbatch_size=2))
    vars_test = {v.name: sess_test.run(v) for v in tf.trainable_variables()}

    for v in vars_ref:
        np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=1e-3)

if __name__ == '__main__':
    test_microbatches()
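The test above asserts that training with `MicrobatchedModel` reproduces full-batch training up to numerical tolerance. The property it relies on is that averaging per-microbatch gradients equals the full-batch gradient when the microbatches are equally sized; a minimal NumPy sketch, with a stand-in "gradient":

```python
import numpy as np

def grad(batch):
    # Stand-in for a per-batch gradient: a mean over the batch dimension
    return batch.mean(axis=0)

batch = np.random.randn(32, 4)
microbatch_size = 2
micrograds = [grad(batch[i:i + microbatch_size])
              for i in range(0, len(batch), microbatch_size)]
assert np.allclose(np.mean(micrograds, axis=0), grad(batch))
```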
@@ -5,7 +5,7 @@ matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
import matplotlib.pyplot as plt
plt.rcParams['svg.fonttype'] = 'none'

from baselines.common import plot_util
from baselines.bench.monitor import load_results

X_TIMESTEPS = 'timesteps'
X_EPISODES = 'episodes'
@@ -16,7 +16,7 @@ POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
EPISODES_WINDOW = 100
COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
          'brown', 'orange', 'teal', 'coral', 'lightblue', 'lime', 'lavender', 'turquoise',
          'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue']
          'darkgreen', 'tan', 'salmon', 'gold', 'lightpurple', 'darkred', 'darkblue']

def rolling_window(a, window):
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
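The `shape` computed in `rolling_window` above sets up a strided view over overlapping windows; presumably the function completes with `as_strided`, along these lines (the `strides` line is an assumption, not shown in the diff):

```python
import numpy as np

def rolling_window(a, window):
    # Overlapping windows as a view, no copying
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)   # assumed continuation
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

x = np.arange(5)
rolling_window(x, 3)                # [[0 1 2], [1 2 3], [2 3 4]]
rolling_window(x, 3).mean(axis=-1)  # rolling mean: [1., 2., 3.]
```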
@@ -50,7 +50,7 @@ def plot_curves(xy_list, xaxis, yaxis, title):
    maxx = max(xy[0][-1] for xy in xy_list)
    minx = 0
    for (i, (x, y)) in enumerate(xy_list):
        color = COLORS[i % len(COLORS)]
        color = COLORS[i]
        plt.scatter(x, y, s=2)
        x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)  # So returns the average of the last EPISODES_WINDOW episodes
        plt.plot(x, y_mean, color=color)
@@ -62,18 +62,19 @@ def plot_curves(xy_list, xaxis, yaxis, title):
    fig.canvas.mpl_connect('resize_event', lambda event: plt.tight_layout())
    plt.grid(True)


def split_by_task(taskpath):
    return taskpath['dirname'].split('/')[-1].split('-')[0]

def plot_results(dirs, num_timesteps=10e6, xaxis=X_TIMESTEPS, yaxis=Y_REWARD, title='', split_fn=split_by_task):
    results = plot_util.load_results(dirs)
    plot_util.plot_results(results, xy_fn=lambda r: ts2xy(r['monitor'], xaxis, yaxis), split_fn=split_fn, average_group=True, resample=int(1e6))
def plot_results(dirs, num_timesteps, xaxis, yaxis, task_name):
    tslist = []
    for dir in dirs:
        ts = load_results(dir)
        ts = ts[ts.l.cumsum() <= num_timesteps]
        tslist.append(ts)
    xy_list = [ts2xy(ts, xaxis, yaxis) for ts in tslist]
    plot_curves(xy_list, xaxis, yaxis, task_name)

# Example usage in jupyter-notebook
# from baselines.results_plotter import plot_results
# from baselines import log_viewer
# %matplotlib inline
# plot_results("./log")
# log_viewer.plot_results(["./log"], 10e6, log_viewer.X_TIMESTEPS, "Breakout")
# Here ./log is a directory containing the monitor.csv files

def main():
@@ -6,7 +6,6 @@ from collections import defaultdict
import tensorflow as tf
import numpy as np

from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session
@@ -63,8 +62,6 @@ def train(args, extra_args):
    alg_kwargs.update(extra_args)

    env = build_env(args)
    if args.save_video_interval != 0:
        env = VecVideoRecorder(env, osp.join(logger.Logger.CURRENT.dir, "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)

    if args.network:
        alg_kwargs['network'] = args.network
@@ -110,8 +107,7 @@ def build_env(args):
        config.gpu_options.allow_growth = True
        get_session(config=config)

        flatten_dict_observations = alg not in {'her'}
        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale, flatten_dict_observations=flatten_dict_observations)
        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)

        if env_type == 'mujoco':
            env = VecNormalize(env)
@@ -120,11 +116,6 @@

def get_env_type(env_id):
    # Re-parse the gym registry, since we could have new envs since last time.
    for env in gym.envs.registry.all():
        env_type = env._entry_point.split(':')[0].split('.')[-1]
        _game_envs[env_type].add(env.id)  # This is a set so add is idempotent

    if env_id in _game_envs.keys():
        env_type = env_id
        env_id = [g for g in _game_envs[env_type]][0]
@@ -187,16 +178,13 @@ def parse_cmdline_kwargs(args):


def main(args):
def main():
    # configure logger, disable logging in child MPI processes (with rank > 0)

    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    if args.extra_import is not None:
        import_module(args.extra_import)

    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
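`parse_known_args` is what lets `run.py` forward unrecognized `--key=value` flags to `parse_cmdline_kwargs` as extra algorithm keyword arguments; a minimal sketch of the split (the parser flags here are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--alg', default='ppo2')
# Known flags land in the namespace; everything else is returned verbatim.
args, unknown = parser.parse_known_args(['--alg=ppo2', '--lr=3e-4'])
print(args.alg)   # ppo2
print(unknown)    # ['--lr=3e-4']
```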
@@ -215,16 +203,11 @@ def main(args):
        logger.log("Running trained model")
        env = build_env(args)
        obs = env.reset()

        state = model.initial_state if hasattr(model, 'initial_state') else None
        dones = np.zeros((1,))

        def initialize_placeholders(nlstm=128, **kwargs):
            return np.zeros((args.num_env or 1, 2*nlstm)), np.zeros((1))
        state, dones = initialize_placeholders(**extra_args)
        while True:
            if state is not None:
                actions, _, state, _ = model.step(obs, S=state, M=dones)
            else:
                actions, _, _, _ = model.step(obs)

            actions, _, state, _ = model.step(obs, S=state, M=dones)
            obs, _, done, _ = env.step(actions)
            env.render()
            done = done.any() if isinstance(done, np.ndarray) else done
@@ -234,7 +217,5 @@ def main(args):

    env.close()

    return model

if __name__ == '__main__':
    main(sys.argv)
    main()
File diff suppressed because one or more lines are too long
@@ -3,5 +3,6 @@ select = F,E999,W291,W293
exclude =
    .git,
    __pycache__,
    baselines/her,
    baselines/ppo1,
    baselines/bench,
7
setup.py
7
setup.py
@@ -11,7 +11,6 @@ extras = {
    'test': [
        'filelock',
        'pytest',
        'pytest-forked',
        'atari-py'
    ],
    'bullet': [
@@ -53,11 +52,11 @@ setup(name='baselines',
# ensure there is some tensorflow build with version at least 1.4
import pkg_resources
tf_pkg = None
for tf_pkg_name in ['tensorflow', 'tensorflow-gpu', 'tf-nightly', 'tf-nightly-gpu']:
for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']:
    try:
        tf_pkg = pkg_resources.get_distribution(tf_pkg_name)
    except pkg_resources.DistributionNotFound:
        pass
assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
from distutils.version import LooseVersion
assert LooseVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= LooseVersion('1.4.0')
from distutils.version import StrictVersion
assert StrictVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')
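`StrictVersion` rejects suffixes such as `rc1` outright, which is why the `re.sub` strips a trailing release-candidate tag before comparing; a minimal sketch of the check:

```python
import re
from distutils.version import StrictVersion

for raw in ['1.4.0rc1', '1.5.0', '1.12.0-rc0']:
    version = re.sub(r'-?rc\d+$', '', raw)   # '1.4.0rc1' -> '1.4.0'
    assert StrictVersion(version) >= StrictVersion('1.4.0')
```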
16
test.dockerfile.py36-mpi
Normal file
16
test.dockerfile.py36-mpi
Normal file
@@ -0,0 +1,16 @@
FROM python:3.6

RUN apt-get -y update && apt-get -y install ffmpeg libopenmpi-dev
ENV CODE_DIR /root/code

COPY . $CODE_DIR/baselines
WORKDIR $CODE_DIR/baselines

# Clean up pycache and pyc files
RUN rm -rf __pycache__ && \
    find . -name "*.pyc" -delete && \
    pip install tensorflow && \
    pip install -e .[test,mpi]

CMD /bin/bash
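Assuming a local Docker install, this image can be built and exercised the same way the CI does, e.g. `docker build -f test.dockerfile.py36-mpi -t baselines-test .` followed by `docker run baselines-test pytest -v .`.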
@@ -1,8 +1,6 @@
FROM python:3.6

RUN apt-get -y update && apt-get -y install ffmpeg
# RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv

ENV CODE_DIR /root/code

COPY . $CODE_DIR/baselines