Compare commits

...

2 Commits

Author SHA1 Message Date
Matthias Plappert
d90638b565 Minor improvements 2018-02-26 10:32:11 +01:00
Matthias Plappert
f4953c3c2d Add Hindsight Experience Replay (HER) 2018-02-26 10:00:17 +01:00
23 changed files with 1740 additions and 37 deletions

View File

@@ -20,6 +20,7 @@ pip install -e .
- [DDPG](baselines/ddpg)
- [DQN](baselines/deepq)
- [GAIL](baselines/gail)
- [HER](baselines/her)
- [PPO1](baselines/ppo1) (Multi-CPU using MPI)
- [PPO2](baselines/ppo2) (Optimized for GPU)
- [TRPO](baselines/trpo_mpi)

View File

@@ -39,12 +39,24 @@ def ortho_init(scale=1.0):
return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
return _ortho_init
def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0):
def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC'):
if data_format == 'NHWC':
channel_ax = 3
strides = [1, stride, stride, 1]
bshape = [1, 1, 1, nf]
elif data_format == 'NCHW':
channel_ax = 1
strides = [1, 1, stride, stride]
bshape = [1, nf, 1, 1]
else:
raise NotImplementedError
nin = x.get_shape()[channel_ax].value
wshape = [rf, rf, nin, nf]
with tf.variable_scope(scope):
nin = x.get_shape()[3].value
w = tf.get_variable("w", [rf, rf, nin, nf], initializer=ortho_init(init_scale))
b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0.0))
return tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=pad)+b
w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale))
b = tf.get_variable("b", [1, nf, 1, 1], initializer=tf.constant_initializer(0.0))
if data_format == 'NHWC': b = tf.reshape(b, bshape)
return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format)
def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
with tf.variable_scope(scope):

View File

@@ -7,12 +7,13 @@ from glob import glob
import csv
import os.path as osp
import json
import numpy as np
class Monitor(Wrapper):
EXT = "monitor.csv"
f = None
def __init__(self, env, filename, allow_early_resets=False, reset_keywords=()):
def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
Wrapper.__init__(self, env=env)
self.tstart = time.time()
if filename is None:
@@ -26,10 +27,12 @@ class Monitor(Wrapper):
filename = filename + "." + Monitor.EXT
self.f = open(filename, "wt")
self.f.write('#%s\n'%json.dumps({"t_start": self.tstart, 'env_id' : env.spec and env.spec.id}))
self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords)
self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords+info_keywords)
self.logger.writeheader()
self.f.flush()
self.reset_keywords = reset_keywords
self.info_keywords = info_keywords
self.allow_early_resets = allow_early_resets
self.rewards = None
self.needs_reset = True
@@ -61,6 +64,8 @@ class Monitor(Wrapper):
eprew = sum(self.rewards)
eplen = len(self.rewards)
epinfo = {"r": round(eprew, 6), "l": eplen, "t": round(time.time() - self.tstart, 6)}
for k in self.info_keywords:
epinfo[k] = info[k]
self.episode_rewards.append(eprew)
self.episode_lengths.append(eplen)
self.episode_times.append(time.time() - self.tstart)

View File

@@ -4,6 +4,7 @@ Helpers for scripts like run_atari.py.
import os
import gym
from gym.wrappers import FlattenDictWrapper
from baselines import logger
from baselines.bench import Monitor
from baselines.common import set_global_seeds
@@ -36,6 +37,19 @@ def make_mujoco_env(env_id, seed):
env.seed(seed)
return env
def make_robotics_env(env_id, seed, rank=0):
"""
Create a wrapped, monitored gym.Env for MuJoCo.
"""
set_global_seeds(seed)
env = gym.make(env_id)
env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
env = Monitor(
env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
info_keywords=('is_success',))
env.seed(seed)
return env
def arg_parser():
"""
Create an empty argparse.ArgumentParser.
@@ -58,7 +72,17 @@ def mujoco_arg_parser():
Create an argparse.ArgumentParser for run_mujoco.py.
"""
parser = arg_parser()
parser.add_argument('--env', help='environment ID', type=str, default="Reacher-v1")
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--num-timesteps', type=int, default=int(1e6))
return parser
def robotics_arg_parser():
"""
Create an argparse.ArgumentParser for run_mujoco.py.
"""
parser = arg_parser()
parser.add_argument('--env', help='environment ID', type=str, default='FetchReach-v0')
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--num-timesteps', type=int, default=int(1e6))
return parser

View File

@@ -16,7 +16,12 @@ def fmt_item(x, l):
if isinstance(x, np.ndarray):
assert x.ndim==0
x = x.item()
if isinstance(x, float): rep = "%g"%x
if isinstance(x, (float, np.float32, np.float64)):
v = abs(x)
if (v < 1e-4 or v > 1e+4) and v > 0:
rep = "%7.2e" % x
else:
rep = "%7.5f" % x
else: rep = str(x)
return " "*(l - len(rep)) + rep

View File

@@ -261,3 +261,20 @@ def get_placeholder_cached(name):
def flattenallbut0(x):
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
# ================================================================
# Diagnostics
# ================================================================
def display_var_info(vars):
from baselines import logger
count_params = 0
for v in vars:
name = v.name
if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue
count_params += np.prod(v.shape.as_list())
if "/b:" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print
logger.info(" %s%s%s" % (name, " "*(55-len(name)), str(v.shape)))
logger.info("Total model parameters: %0.1f million" % (count_params*1e-6))

View File

@@ -1,4 +1,5 @@
import numpy as np
import gym
from . import VecEnv
class DummyVecEnv(VecEnv):
@@ -6,26 +7,36 @@ class DummyVecEnv(VecEnv):
self.envs = [fn() for fn in env_fns]
env = self.envs[0]
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
self.ts = np.zeros(len(self.envs), dtype='int')
obs_spaces = self.observation_space.spaces if isinstance(self.observation_space, gym.spaces.Tuple) else (self.observation_space,)
self.buf_obs = [np.zeros((self.num_envs,) + tuple(s.shape), s.dtype) for s in obs_spaces]
self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos = [{} for _ in range(self.num_envs)]
self.actions = None
def step_async(self, actions):
self.actions = actions
def step_wait(self):
results = [env.step(a) for (a,env) in zip(self.actions, self.envs)]
obs, rews, dones, infos = map(np.array, zip(*results))
self.ts += 1
for (i, done) in enumerate(dones):
if done:
obs[i] = self.envs[i].reset()
self.ts[i] = 0
self.actions = None
return np.array(obs), np.array(rews), np.array(dones), infos
for i in range(self.num_envs):
obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i])
if isinstance(obs_tuple, (tuple, list)):
for t,x in enumerate(obs_tuple):
self.buf_obs[t][i] = x
else:
self.buf_obs[0][i] = obs_tuple
return self.buf_obs, self.buf_rews, self.buf_dones, self.buf_infos
def reset(self):
results = [env.reset() for env in self.envs]
return np.array(results)
for i in range(self.num_envs):
obs_tuple = self.envs[i].reset()
if isinstance(obs_tuple, (tuple, list)):
for t,x in enumerate(obs_tuple):
self.buf_obs[t][i] = x
else:
self.buf_obs[0][i] = obs_tuple
return self.buf_obs
def close(self):
return

35
baselines/her/README.md Normal file
View File

@@ -0,0 +1,35 @@
# Hindsight Experience Replay
For details on Hindsight Experience Replay (HER), please read the [paper](https://arxiv.org/pdf/1707.01495.pdf).
## How to use Hindsight Experience Replay
### Getting started
Training an agent is very simple:
```bash
python -m baselines.her.experiment.train
```
This will train a DDPG+HER agent on the `FetchReach` environment.
You should see the success rate go up quickly to `1.0`, which means that the agent achieves the
desired goal in 100% of the cases.
The training script logs other diagnostics as well and pickles the best policy so far (w.r.t. to its test success rate),
the latest policy, and, if enabled, a history of policies every K epochs.
To inspect what the agent has learned, use the play script:
```bash
python -m baselines.her.experiment.play /path/to/an/experiment/policy_best.pkl
```
You can try it right now with the results of the training step (the script prints out the path for you).
This should visualize the current policy for 10 episodes and will also print statistics.
### Advanced usage
The train script comes with advanced features like MPI support, that allows to scale across all cores of a single machine.
To see all available options, simply run this command:
```bash
python -m baselines.her.experiment.train --help
```
To run on, say, 20 CPU cores, you can use the following command:
```bash
python -m baselines.her.experiment.train --num_cpu 20
```
That's it, you are now running rollouts using 20 MPI workers and average gradients for network updates across all 20 core.

View File

View File

@@ -0,0 +1,44 @@
import tensorflow as tf
from baselines.her.util import store_args, nn
class ActorCritic:
@store_args
def __init__(self, inputs_tf, dimo, dimg, dimu, max_u, o_stats, g_stats, hidden, layers,
**kwargs):
"""The actor-critic network and related training code.
Args:
inputs_tf (dict of tensors): all necessary inputs for the network: the
observation (o), the goal (g), and the action (u)
dimo (int): the dimension of the observations
dimg (int): the dimension of the goals
dimu (int): the dimension of the actions
max_u (float): the maximum magnitude of actions; action outputs will be scaled
accordingly
o_stats (baselines.her.Normalizer): normalizer for observations
g_stats (baselines.her.Normalizer): normalizer for goals
hidden (int): number of hidden units that should be used in hidden layers
layers (int): number of hidden layers
"""
self.o_tf = inputs_tf['o']
self.g_tf = inputs_tf['g']
self.u_tf = inputs_tf['u']
# Prepare inputs for actor and critic.
o = self.o_stats.normalize(self.o_tf)
g = self.g_stats.normalize(self.g_tf)
input_pi = tf.concat(axis=1, values=[o, g]) # for actor
# Networks.
with tf.variable_scope('pi'):
self.pi_tf = self.max_u * tf.tanh(nn(
input_pi, [self.hidden] * self.layers + [self.dimu]))
with tf.variable_scope('Q'):
# for policy training
input_Q = tf.concat(axis=1, values=[o, g, self.pi_tf / self.max_u])
self.Q_pi_tf = nn(input_Q, [self.hidden] * self.layers + [1])
# for critic training
input_Q = tf.concat(axis=1, values=[o, g, self.u_tf / self.max_u])
self._input_Q = input_Q # exposed for tests
self.Q_tf = nn(input_Q, [self.hidden] * self.layers + [1], reuse=True)

340
baselines/her/ddpg.py Normal file
View File

@@ -0,0 +1,340 @@
from collections import OrderedDict
import numpy as np
import tensorflow as tf
from tensorflow.contrib.staging import StagingArea
from baselines import logger
from baselines.her.util import (
import_function, store_args, flatten_grads, transitions_in_episode_batch)
from baselines.her.normalizer import Normalizer
from baselines.her.replay_buffer import ReplayBuffer
from baselines.common.mpi_adam import MpiAdam
def dims_to_shapes(input_dims):
return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()}
class DDPG(object):
@store_args
def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
sample_transitions, gamma, reuse=False, **kwargs):
"""Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
Args:
input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
actions (u)
buffer_size (int): number of transitions that are stored in the replay buffer
hidden (int): number of units in the hidden layers
layers (int): number of hidden layers
network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
polyak (float): coefficient for Polyak-averaging of the target network
batch_size (int): batch size for training
Q_lr (float): learning rate for the Q (critic) network
pi_lr (float): learning rate for the pi (actor) network
norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
action_l2 (float): coefficient for L2 penalty on the actions
clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
scope (str): the scope used for the TensorFlow graph
T (int): the time horizon for rollouts
rollout_batch_size (int): number of parallel rollouts per DDPG agent
subtract_goals (function): function that subtracts goals from each other
relative_goals (boolean): whether or not relative goals should be fed into the network
clip_pos_returns (boolean): whether or not positive returns should be clipped
clip_return (float): clip returns to be in [-clip_return, clip_return]
sample_transitions (function) function that samples from the replay buffer
gamma (float): gamma used for Q learning updates
reuse (boolean): whether or not the networks should be reused
"""
if self.clip_return is None:
self.clip_return = np.inf
self.create_actor_critic = import_function(self.network_class)
input_shapes = dims_to_shapes(self.input_dims)
self.dimo = self.input_dims['o']
self.dimg = self.input_dims['g']
self.dimu = self.input_dims['u']
# Prepare staging area for feeding data to the model.
stage_shapes = OrderedDict()
for key in sorted(self.input_dims.keys()):
if key.startswith('info_'):
continue
stage_shapes[key] = (None, *input_shapes[key])
for key in ['o', 'g']:
stage_shapes[key + '_2'] = stage_shapes[key]
stage_shapes['r'] = (None,)
self.stage_shapes = stage_shapes
# Create network.
with tf.variable_scope(self.scope):
self.staging_tf = StagingArea(
dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
shapes=list(self.stage_shapes.values()))
self.buffer_ph_tf = [
tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()]
self.stage_op = self.staging_tf.put(self.buffer_ph_tf)
self._create_network(reuse=reuse)
# Configure the replay buffer.
buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *input_shapes[key])
for key, val in input_shapes.items()}
buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
buffer_shapes['ag'] = (self.T+1, self.dimg)
buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)
def _random_action(self, n):
return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))
def _preprocess_og(self, o, ag, g):
if self.relative_goals:
g_shape = g.shape
g = g.reshape(-1, self.dimg)
ag = ag.reshape(-1, self.dimg)
g = self.subtract_goals(g, ag)
g = g.reshape(*g_shape)
o = np.clip(o, -self.clip_obs, self.clip_obs)
g = np.clip(g, -self.clip_obs, self.clip_obs)
return o, g
def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
compute_Q=False):
o, g = self._preprocess_og(o, ag, g)
policy = self.target if use_target_net else self.main
# values to compute
vals = [policy.pi_tf]
if compute_Q:
vals += [policy.Q_pi_tf]
# feed
feed = {
policy.o_tf: o.reshape(-1, self.dimo),
policy.g_tf: g.reshape(-1, self.dimg),
policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
}
ret = self.sess.run(vals, feed_dict=feed)
# action postprocessing
u = ret[0]
noise = noise_eps * self.max_u * np.random.randn(*u.shape) # gaussian noise
u += noise
u = np.clip(u, -self.max_u, self.max_u)
u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u) # eps-greedy
if u.shape[0] == 1:
u = u[0]
u = u.copy()
ret[0] = u
if len(ret) == 1:
return ret[0]
else:
return ret
def store_episode(self, episode_batch, update_stats=True):
"""
episode_batch: array of batch_size x (T or T+1) x dim_key
'o' is of size T+1, others are of size T
"""
self.buffer.store_episode(episode_batch)
if update_stats:
# add transitions to normalizer
episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)
o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
# No need to preprocess the o_2 and g_2 since this is only used for stats
self.o_stats.update(transitions['o'])
self.g_stats.update(transitions['g'])
self.o_stats.recompute_stats()
self.g_stats.recompute_stats()
def get_current_buffer_size(self):
return self.buffer.get_current_size()
def _sync_optimizers(self):
self.Q_adam.sync()
self.pi_adam.sync()
def _grads(self):
# Avoid feed_dict here for performance!
critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
self.Q_loss_tf,
self.main.Q_pi_tf,
self.Q_grad_tf,
self.pi_grad_tf
])
return critic_loss, actor_loss, Q_grad, pi_grad
def _update(self, Q_grad, pi_grad):
self.Q_adam.update(Q_grad, self.Q_lr)
self.pi_adam.update(pi_grad, self.pi_lr)
def sample_batch(self):
transitions = self.buffer.sample(self.batch_size)
o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
ag, ag_2 = transitions['ag'], transitions['ag_2']
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g)
transitions_batch = [transitions[key] for key in self.stage_shapes.keys()]
return transitions_batch
def stage_batch(self, batch=None):
if batch is None:
batch = self.sample_batch()
assert len(self.buffer_ph_tf) == len(batch)
self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))
def train(self, stage=True):
if stage:
self.stage_batch()
critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
self._update(Q_grad, pi_grad)
return critic_loss, actor_loss
def _init_target_net(self):
self.sess.run(self.init_target_net_op)
def update_target_net(self):
self.sess.run(self.update_target_net_op)
def clear_buffer(self):
self.buffer.clear_buffer()
def _vars(self, scope):
res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope)
assert len(res) > 0
return res
def _global_vars(self, scope):
res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope)
return res
def _create_network(self, reuse=False):
logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))
self.sess = tf.get_default_session()
if self.sess is None:
self.sess = tf.InteractiveSession()
# running averages
with tf.variable_scope('o_stats') as vs:
if reuse:
vs.reuse_variables()
self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess)
with tf.variable_scope('g_stats') as vs:
if reuse:
vs.reuse_variables()
self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess)
# mini-batch sampling.
batch = self.staging_tf.get()
batch_tf = OrderedDict([(key, batch[i])
for i, key in enumerate(self.stage_shapes.keys())])
batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])
# networks
with tf.variable_scope('main') as vs:
if reuse:
vs.reuse_variables()
self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__)
vs.reuse_variables()
with tf.variable_scope('target') as vs:
if reuse:
vs.reuse_variables()
target_batch_tf = batch_tf.copy()
target_batch_tf['o'] = batch_tf['o_2']
target_batch_tf['g'] = batch_tf['g_2']
self.target = self.create_actor_critic(
target_batch_tf, net_type='target', **self.__dict__)
vs.reuse_variables()
assert len(self._vars("main")) == len(self._vars("target"))
# loss functions
target_Q_pi_tf = self.target.Q_pi_tf
clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
assert len(self._vars('main/Q')) == len(Q_grads_tf)
assert len(self._vars('main/pi')) == len(pi_grads_tf)
self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))
# optimizers
self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)
# polyak averaging
self.main_vars = self._vars('main/Q') + self._vars('main/pi')
self.target_vars = self._vars('target/Q') + self._vars('target/pi')
self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
self.init_target_net_op = list(
map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
self.update_target_net_op = list(
map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))
# initialize all variables
tf.variables_initializer(self._global_vars('')).run()
self._sync_optimizers()
self._init_target_net()
def logs(self, prefix=''):
logs = []
logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
if prefix is not '' and not prefix.endswith('/'):
return [(prefix + '/' + key, val) for key, val in logs]
else:
return logs
def __getstate__(self):
"""Our policies can be loaded from pkl, but after unpickling you cannot continue training.
"""
excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats',
'main', 'target', 'lock', 'env', 'sample_transitions',
'stage_shapes', 'create_actor_critic']
state = {k: v for k, v in self.__dict__.items() if all([not subname in k for subname in excluded_subnames])}
state['buffer_size'] = self.buffer_size
state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name])
return state
def __setstate__(self, state):
if 'sample_transitions' not in state:
# We don't need this for playing the policy.
state['sample_transitions'] = None
self.__init__(**state)
# set up stats (they are overwritten in __init__)
for k, v in state.items():
if k[-6:] == '_stats':
self.__dict__[k] = v
# load TF variables
vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
assert(len(vars) == len(state["tf"]))
node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
self.sess.run(node)

View File

View File

@@ -0,0 +1,170 @@
from copy import deepcopy
import numpy as np
import json
import os
import gym
from baselines import logger
from baselines.her.ddpg import DDPG
from baselines.her.her import make_sample_her_transitions
DEFAULT_ENV_PARAMS = {
'FetchReach-v0': {
'n_cycles': 10,
},
}
DEFAULT_PARAMS = {
# env
'max_u': 1., # max absolute value of actions on different coordinates
# ddpg
'layers': 3, # number of layers in the critic/actor networks
'hidden': 256, # number of neurons in each hidden layers
'network_class': 'baselines.her.actor_critic:ActorCritic',
'Q_lr': 0.001, # critic learning rate
'pi_lr': 0.001, # actor learning rate
'buffer_size': int(1E6), # for experience replay
'polyak': 0.95, # polyak averaging coefficient
'action_l2': 1.0, # quadratic penalty on actions (before rescaling by max_u)
'clip_obs': 200.,
'scope': 'ddpg', # can be tweaked for testing
'relative_goals': False,
# training
'n_cycles': 50, # per epoch
'rollout_batch_size': 2, # per mpi thread
'n_batches': 40, # training batches per cycle
'batch_size': 256, # per mpi thread, measured in transitions and reduced to even multiple of chunk_length.
'n_test_rollouts': 10, # number of test rollouts per epoch, each consists of rollout_batch_size rollouts
'test_with_polyak': False, # run test episodes with the target network
# exploration
'random_eps': 0.3, # percentage of time a random action is taken
'noise_eps': 0.2, # std of gaussian noise added to not-completely-random actions as a percentage of max_u
# HER
'replay_strategy': 'future', # supported modes: future, none
'replay_k': 4, # number of additional goals used for replay, only used if off_policy_data=future
# normalization
'norm_eps': 0.01, # epsilon used for observation normalization
'norm_clip': 5, # normalized observations are cropped to this values
}
CACHED_ENVS = {}
def cached_make_env(make_env):
"""
Only creates a new environment from the provided function if one has not yet already been
created. This is useful here because we need to infer certain properties of the env, e.g.
its observation and action spaces, without any intend of actually using it.
"""
if make_env not in CACHED_ENVS:
env = make_env()
CACHED_ENVS[make_env] = env
return CACHED_ENVS[make_env]
def prepare_params(kwargs):
# DDPG params
ddpg_params = dict()
env_name = kwargs['env_name']
def make_env():
return gym.make(env_name)
kwargs['make_env'] = make_env
tmp_env = cached_make_env(kwargs['make_env'])
assert hasattr(tmp_env, '_max_episode_steps')
kwargs['T'] = tmp_env._max_episode_steps
tmp_env.reset()
kwargs['max_u'] = np.array(kwargs['max_u']) if type(kwargs['max_u']) == list else kwargs['max_u']
kwargs['gamma'] = 1. - 1. / kwargs['T']
if 'lr' in kwargs:
kwargs['pi_lr'] = kwargs['lr']
kwargs['Q_lr'] = kwargs['lr']
del kwargs['lr']
for name in ['buffer_size', 'hidden', 'layers',
'network_class',
'polyak',
'batch_size', 'Q_lr', 'pi_lr',
'norm_eps', 'norm_clip', 'max_u',
'action_l2', 'clip_obs', 'scope', 'relative_goals']:
ddpg_params[name] = kwargs[name]
kwargs['_' + name] = kwargs[name]
del kwargs[name]
kwargs['ddpg_params'] = ddpg_params
return kwargs
def log_params(params, logger=logger):
for key in sorted(params.keys()):
logger.info('{}: {}'.format(key, params[key]))
def configure_her(params):
env = cached_make_env(params['make_env'])
env.reset()
def reward_fun(ag_2, g, info): # vectorized
return env.compute_reward(achieved_goal=ag_2, desired_goal=g, info=info)
# Prepare configuration for HER.
her_params = {
'reward_fun': reward_fun,
}
for name in ['replay_strategy', 'replay_k']:
her_params[name] = params[name]
params['_' + name] = her_params[name]
del params[name]
sample_her_transitions = make_sample_her_transitions(**her_params)
return sample_her_transitions
def simple_goal_subtract(a, b):
assert a.shape == b.shape
return a - b
def configure_ddpg(dims, params, reuse=False, use_mpi=True, clip_return=True):
sample_her_transitions = configure_her(params)
# Extract relevant parameters.
gamma = params['gamma']
rollout_batch_size = params['rollout_batch_size']
ddpg_params = params['ddpg_params']
input_dims = dims.copy()
# DDPG agent
env = cached_make_env(params['make_env'])
env.reset()
ddpg_params.update({'input_dims': input_dims, # agent takes an input observations
'T': params['T'],
'clip_pos_returns': True, # clip positive returns
'clip_return': (1. / (1. - gamma)) if clip_return else np.inf, # max abs of return
'rollout_batch_size': rollout_batch_size,
'subtract_goals': simple_goal_subtract,
'sample_transitions': sample_her_transitions,
'gamma': gamma,
})
ddpg_params['info'] = {
'env_name': params['env_name'],
}
policy = DDPG(reuse=reuse, **ddpg_params, use_mpi=use_mpi)
return policy
def configure_dims(params):
env = cached_make_env(params['make_env'])
env.reset()
obs, _, _, info = env.step(env.action_space.sample())
dims = {
'o': obs['observation'].shape[0],
'u': env.action_space.shape[0],
'g': obs['desired_goal'].shape[0],
}
for key, value in info.items():
value = np.array(value)
if value.ndim == 0:
value = value.reshape(1)
dims['info_{}'.format(key)] = value.shape[0]
return dims

View File

@@ -0,0 +1,60 @@
import click
import numpy as np
import pickle
from baselines import logger
from baselines.common import set_global_seeds
import baselines.her.experiment.config as config
from baselines.her.rollout import RolloutWorker
@click.command()
@click.argument('policy_file', type=str)
@click.option('--seed', type=int, default=0)
@click.option('--n_test_rollouts', type=int, default=10)
@click.option('--render', type=int, default=1)
def main(policy_file, seed, n_test_rollouts, render):
set_global_seeds(seed)
# Load policy.
with open(policy_file, 'rb') as f:
policy = pickle.load(f)
env_name = policy.info['env_name']
# Prepare params.
params = config.DEFAULT_PARAMS
if env_name in config.DEFAULT_ENV_PARAMS:
params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in
params['env_name'] = env_name
params = config.prepare_params(params)
config.log_params(params, logger=logger)
dims = config.configure_dims(params)
eval_params = {
'exploit': True,
'use_target_net': params['test_with_polyak'],
'compute_Q': True,
'rollout_batch_size': 1,
'render': bool(render),
}
for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
eval_params[name] = params[name]
evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
evaluator.seed(seed)
# Run evaluation.
evaluator.clear_history()
for _ in range(n_test_rollouts):
evaluator.generate_rollouts()
# record logs
for key, val in evaluator.logs('test'):
logger.record_tabular(key, np.mean(val))
logger.dump_tabular()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,118 @@
import os
import matplotlib.pyplot as plt
import numpy as np
import json
import seaborn as sns; sns.set()
import glob2
import argparse
def smooth_reward_curve(x, y):
halfwidth = int(np.ceil(len(x) / 60)) # Halfwidth of our smoothing convolution
k = halfwidth
xsmoo = x
ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='same') / np.convolve(np.ones_like(y), np.ones(2 * k + 1),
mode='same')
return xsmoo, ysmoo
def load_results(file):
if not os.path.exists(file):
return None
with open(file, 'r') as f:
lines = [line for line in f]
if len(lines) < 2:
return None
keys = [name.strip() for name in lines[0].split(',')]
data = np.genfromtxt(file, delimiter=',', skip_header=1, filling_values=0.)
if data.ndim == 1:
data = data.reshape(1, -1)
assert data.ndim == 2
assert data.shape[-1] == len(keys)
result = {}
for idx, key in enumerate(keys):
result[key] = data[:, idx]
return result
def pad(xs, value=np.nan):
maxlen = np.max([len(x) for x in xs])
padded_xs = []
for x in xs:
if x.shape[0] >= maxlen:
padded_xs.append(x)
padding = np.ones((maxlen - x.shape[0],) + x.shape[1:]) * value
x_padded = np.concatenate([x, padding], axis=0)
assert x_padded.shape[1:] == x.shape[1:]
assert x_padded.shape[0] == maxlen
padded_xs.append(x_padded)
return np.array(padded_xs)
parser = argparse.ArgumentParser()
parser.add_argument('dir', type=str)
parser.add_argument('--smooth', type=int, default=1)
args = parser.parse_args()
# Load all data.
data = {}
paths = [os.path.abspath(os.path.join(path, '..')) for path in glob2.glob(os.path.join(args.dir, '**', 'progress.csv'))]
for curr_path in paths:
if not os.path.isdir(curr_path):
continue
results = load_results(os.path.join(curr_path, 'progress.csv'))
if not results:
print('skipping {}'.format(curr_path))
continue
print('loading {} ({})'.format(curr_path, len(results['epoch'])))
with open(os.path.join(curr_path, 'metadata.json'), 'r') as f:
metadata = json.load(f)
success_rate = np.array(results['test/success_rate'])
epoch = np.array(results['epoch']) + 1
env_id = metadata['kwargs']['env_name']
replay_strategy = metadata['kwargs']['replay_strategy']
if replay_strategy == 'future':
config = 'her'
else:
config = 'ddpg'
if 'Dense' in env_id:
config += '-dense'
else:
config += '-sparse'
env_id = env_id.replace('Dense', '')
# Process and smooth data.
assert success_rate.shape == epoch.shape
x = epoch
y = success_rate
if args.smooth:
x, y = smooth_reward_curve(epoch, success_rate)
assert x.shape == y.shape
if env_id not in data:
data[env_id] = {}
if config not in data[env_id]:
data[env_id][config] = []
data[env_id][config].append((x, y))
# Plot data.
for env_id in sorted(data.keys()):
print('exporting {}'.format(env_id))
plt.clf()
for config in sorted(data[env_id].keys()):
xs, ys = zip(*data[env_id][config])
xs, ys = pad(xs), pad(ys)
assert xs.shape == ys.shape
plt.plot(xs[0], np.nanmedian(ys, axis=0), label=config)
plt.fill_between(xs[0], np.nanpercentile(ys, 25, axis=0), np.nanpercentile(ys, 75, axis=0), alpha=0.25)
plt.title(env_id)
plt.xlabel('Epoch')
plt.ylabel('Median Success Rate')
plt.legend()
plt.savefig(os.path.join(args.dir, 'fig_{}.png'.format(env_id)))

View File

@@ -0,0 +1,172 @@
import os
import sys
import click
import numpy as np
import json
from mpi4py import MPI
from baselines import logger
from baselines.common import set_global_seeds
from baselines.common.mpi_moments import mpi_moments
import baselines.her.experiment.config as config
from baselines.her.rollout import RolloutWorker
from baselines.her.util import mpi_fork
def mpi_average(value):
if value == []:
value = [0.]
if not isinstance(value, list):
value = [value]
return mpi_moments(np.array(value))[0]
def train(policy, rollout_worker, evaluator,
n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
save_policies, **kwargs):
rank = MPI.COMM_WORLD.Get_rank()
latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl')
best_policy_path = os.path.join(logger.get_dir(), 'policy_best.pkl')
periodic_policy_path = os.path.join(logger.get_dir(), 'policy_{}.pkl')
logger.info("Training...")
best_success_rate = -1
for epoch in range(n_epochs):
# train
rollout_worker.clear_history()
for _ in range(n_cycles):
episode = rollout_worker.generate_rollouts()
policy.store_episode(episode)
for _ in range(n_batches):
policy.train()
policy.update_target_net()
# test
evaluator.clear_history()
for _ in range(n_test_rollouts):
evaluator.generate_rollouts()
# record logs
logger.record_tabular('epoch', epoch)
for key, val in evaluator.logs('test'):
logger.record_tabular(key, mpi_average(val))
for key, val in rollout_worker.logs('train'):
logger.record_tabular(key, mpi_average(val))
for key, val in policy.logs():
logger.record_tabular(key, mpi_average(val))
if rank == 0:
logger.dump_tabular()
# save the policy if it's better than the previous ones
success_rate = mpi_average(evaluator.current_success_rate())
if rank == 0 and success_rate >= best_success_rate and save_policies:
best_success_rate = success_rate
logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
evaluator.save_policy(best_policy_path)
evaluator.save_policy(latest_policy_path)
if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_policies:
policy_path = periodic_policy_path.format(epoch)
logger.info('Saving periodic policy to {} ...'.format(policy_path))
evaluator.save_policy(policy_path)
# make sure that different threads have different seeds
local_uniform = np.random.uniform(size=(1,))
root_uniform = local_uniform.copy()
MPI.COMM_WORLD.Bcast(root_uniform, root=0)
if rank != 0:
assert local_uniform[0] != root_uniform[0]
def launch(
env_name, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return,
override_params={}, save_policies=True
):
# Fork for multi-CPU MPI implementation.
if num_cpu > 1:
whoami = mpi_fork(num_cpu)
if whoami == 'parent':
sys.exit(0)
import baselines.common.tf_util as U
U.single_threaded_session().__enter__()
rank = MPI.COMM_WORLD.Get_rank()
# Configure logging
if rank == 0:
if logdir or logger.get_dir() is None:
logger.configure(dir=logdir)
else:
logger.configure()
logdir = logger.get_dir()
assert logdir is not None
os.makedirs(logdir, exist_ok=True)
# Seed everything.
rank_seed = seed + 1000000 * rank
set_global_seeds(rank_seed)
# Prepare params.
params = config.DEFAULT_PARAMS
params['env_name'] = env_name
params['replay_strategy'] = replay_strategy
if env_name in config.DEFAULT_ENV_PARAMS:
params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in
params.update(**override_params) # makes it possible to override any parameter
with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
json.dump(params, f)
params = config.prepare_params(params)
config.log_params(params, logger=logger)
dims = config.configure_dims(params)
policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
rollout_params = {
'exploit': False,
'use_target_net': False,
'use_demo_states': True,
'compute_Q': False,
'T': params['T'],
}
eval_params = {
'exploit': True,
'use_target_net': params['test_with_polyak'],
'use_demo_states': False,
'compute_Q': True,
'T': params['T'],
}
for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
rollout_params[name] = params[name]
eval_params[name] = params[name]
rollout_worker = RolloutWorker(params['make_env'], policy, dims, logger, **rollout_params)
rollout_worker.seed(rank_seed)
evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
evaluator.seed(rank_seed)
train(
logdir=logdir, policy=policy, rollout_worker=rollout_worker,
evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
n_cycles=params['n_cycles'], n_batches=params['n_batches'],
policy_save_interval=policy_save_interval, save_policies=save_policies)
@click.command()
@click.option('--env_name', type=str, default='FetchReach-v0', help='the name of the OpenAI Gym environment that you want to train on')
@click.option('--logdir', type=str, default=None, help='the path to where logs and policy pickles should go. If not specified, creates a folder in /tmp/')
@click.option('--n_epochs', type=int, default=50, help='the number of training epochs to run')
@click.option('--num_cpu', type=int, default=1, help='the number of CPU cores to use (using MPI)')
@click.option('--seed', type=int, default=0, help='the random seed used to seed both the environment and the training code')
@click.option('--policy_save_interval', type=int, default=5, help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.')
@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.')
@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped')
def main(**kwargs):
launch(**kwargs)
if __name__ == '__main__':
main()

63
baselines/her/her.py Normal file
View File

@@ -0,0 +1,63 @@
import numpy as np
def make_sample_her_transitions(replay_strategy, replay_k, reward_fun):
"""Creates a sample function that can be used for HER experience replay.
Args:
replay_strategy (in ['future', 'none']): the HER replay strategy; if set to 'none',
regular DDPG experience replay is used
replay_k (int): the ratio between HER replays and regular replays (e.g. k = 4 -> 4 times
as many HER replays as regular replays are used)
reward_fun (function): function to re-compute the reward with substituted goals
"""
if replay_strategy == 'future':
future_p = 1 - (1. / (1 + replay_k))
else: # 'replay_strategy' == 'none'
future_p = 0
def _sample_her_transitions(episode_batch, batch_size_in_transitions):
"""episode_batch is {key: array(buffer_size x T x dim_key)}
"""
T = episode_batch['u'].shape[1]
rollout_batch_size = episode_batch['u'].shape[0]
batch_size = batch_size_in_transitions
# Select which episodes and time steps to use.
episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
t_samples = np.random.randint(T, size=batch_size)
transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
for key in episode_batch.keys()}
# Select future time indexes proportional with probability future_p. These
# will be used for HER replay by substituting in future goals.
her_indexes = np.where(np.random.uniform(size=batch_size) < future_p)
future_offset = np.random.uniform(size=batch_size) * (T - t_samples)
future_offset = future_offset.astype(int)
future_t = (t_samples + 1 + future_offset)[her_indexes]
# Replace goal with achieved goal but only for the previously-selected
# HER transitions (as defined by her_indexes). For the other transitions,
# keep the original goal.
future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
transitions['g'][her_indexes] = future_ag
# Reconstruct info dictionary for reward computation.
info = {}
for key, value in transitions.items():
if key.startswith('info_'):
info[key.replace('info_', '')] = value
# Re-compute reward since we may have substituted the goal.
reward_params = {k: transitions[k] for k in ['ag_2', 'g']}
reward_params['info'] = info
transitions['r'] = reward_fun(**reward_params)
transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:])
for k in transitions.keys()}
assert(transitions['u'].shape[0] == batch_size_in_transitions)
return transitions
return _sample_her_transitions

140
baselines/her/normalizer.py Normal file
View File

@@ -0,0 +1,140 @@
import threading
import numpy as np
from mpi4py import MPI
import tensorflow as tf
from baselines.her.util import reshape_for_broadcasting
class Normalizer:
def __init__(self, size, eps=1e-2, default_clip_range=np.inf, sess=None):
"""A normalizer that ensures that observations are approximately distributed according to
a standard Normal distribution (i.e. have mean zero and variance one).
Args:
size (int): the size of the observation to be normalized
eps (float): a small constant that avoids underflows
default_clip_range (float): normalized observations are clipped to be in
[-default_clip_range, default_clip_range]
sess (object): the TensorFlow session to be used
"""
self.size = size
self.eps = eps
self.default_clip_range = default_clip_range
self.sess = sess if sess is not None else tf.get_default_session()
self.local_sum = np.zeros(self.size, np.float32)
self.local_sumsq = np.zeros(self.size, np.float32)
self.local_count = np.zeros(1, np.float32)
self.sum_tf = tf.get_variable(
initializer=tf.zeros_initializer(), shape=self.local_sum.shape, name='sum',
trainable=False, dtype=tf.float32)
self.sumsq_tf = tf.get_variable(
initializer=tf.zeros_initializer(), shape=self.local_sumsq.shape, name='sumsq',
trainable=False, dtype=tf.float32)
self.count_tf = tf.get_variable(
initializer=tf.ones_initializer(), shape=self.local_count.shape, name='count',
trainable=False, dtype=tf.float32)
self.mean = tf.get_variable(
initializer=tf.zeros_initializer(), shape=(self.size,), name='mean',
trainable=False, dtype=tf.float32)
self.std = tf.get_variable(
initializer=tf.ones_initializer(), shape=(self.size,), name='std',
trainable=False, dtype=tf.float32)
self.count_pl = tf.placeholder(name='count_pl', shape=(1,), dtype=tf.float32)
self.sum_pl = tf.placeholder(name='sum_pl', shape=(self.size,), dtype=tf.float32)
self.sumsq_pl = tf.placeholder(name='sumsq_pl', shape=(self.size,), dtype=tf.float32)
self.update_op = tf.group(
self.count_tf.assign_add(self.count_pl),
self.sum_tf.assign_add(self.sum_pl),
self.sumsq_tf.assign_add(self.sumsq_pl)
)
self.recompute_op = tf.group(
tf.assign(self.mean, self.sum_tf / self.count_tf),
tf.assign(self.std, tf.sqrt(tf.maximum(
tf.square(self.eps),
self.sumsq_tf / self.count_tf - tf.square(self.sum_tf / self.count_tf)
))),
)
self.lock = threading.Lock()
def update(self, v):
v = v.reshape(-1, self.size)
with self.lock:
self.local_sum += v.sum(axis=0)
self.local_sumsq += (np.square(v)).sum(axis=0)
self.local_count[0] += v.shape[0]
def normalize(self, v, clip_range=None):
if clip_range is None:
clip_range = self.default_clip_range
mean = reshape_for_broadcasting(self.mean, v)
std = reshape_for_broadcasting(self.std, v)
return tf.clip_by_value((v - mean) / std, -clip_range, clip_range)
def denormalize(self, v):
mean = reshape_for_broadcasting(self.mean, v)
std = reshape_for_broadcasting(self.std, v)
return mean + v * std
def _mpi_average(self, x):
buf = np.zeros_like(x)
MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM)
buf /= MPI.COMM_WORLD.Get_size()
return buf
def synchronize(self, local_sum, local_sumsq, local_count, root=None):
local_sum[...] = self._mpi_average(local_sum)
local_sumsq[...] = self._mpi_average(local_sumsq)
local_count[...] = self._mpi_average(local_count)
return local_sum, local_sumsq, local_count
def recompute_stats(self):
with self.lock:
# Copy over results.
local_count = self.local_count.copy()
local_sum = self.local_sum.copy()
local_sumsq = self.local_sumsq.copy()
# Reset.
self.local_count[...] = 0
self.local_sum[...] = 0
self.local_sumsq[...] = 0
# We perform the synchronization outside of the lock to keep the critical section as short
# as possible.
synced_sum, synced_sumsq, synced_count = self.synchronize(
local_sum=local_sum, local_sumsq=local_sumsq, local_count=local_count)
self.sess.run(self.update_op, feed_dict={
self.count_pl: synced_count,
self.sum_pl: synced_sum,
self.sumsq_pl: synced_sumsq,
})
self.sess.run(self.recompute_op)
class IdentityNormalizer:
def __init__(self, size, std=1.):
self.size = size
self.mean = tf.zeros(self.size, tf.float32)
self.std = std * tf.ones(self.size, tf.float32)
def update(self, x):
pass
def normalize(self, x, clip_range=None):
return x / self.std
def denormalize(self, x):
return self.std * x
def synchronize(self):
pass
def recompute_stats(self):
pass

View File

@@ -0,0 +1,108 @@
import threading
import numpy as np
class ReplayBuffer:
def __init__(self, buffer_shapes, size_in_transitions, T, sample_transitions):
"""Creates a replay buffer.
Args:
buffer_shapes (dict of ints): the shape for all buffers that are used in the replay
buffer
size_in_transitions (int): the size of the buffer, measured in transitions
T (int): the time horizon for episodes
sample_transitions (function): a function that samples from the replay buffer
"""
self.buffer_shapes = buffer_shapes
self.size = size_in_transitions // T
self.T = T
self.sample_transitions = sample_transitions
# self.buffers is {key: array(size_in_episodes x T or T+1 x dim_key)}
self.buffers = {key: np.empty([self.size, *shape])
for key, shape in buffer_shapes.items()}
# memory management
self.current_size = 0
self.n_transitions_stored = 0
self.lock = threading.Lock()
@property
def full(self):
with self.lock:
return self.current_size == self.size
def sample(self, batch_size):
"""Returns a dict {key: array(batch_size x shapes[key])}
"""
buffers = {}
with self.lock:
assert self.current_size > 0
for key in self.buffers.keys():
buffers[key] = self.buffers[key][:self.current_size]
buffers['o_2'] = buffers['o'][:, 1:, :]
buffers['ag_2'] = buffers['ag'][:, 1:, :]
transitions = self.sample_transitions(buffers, batch_size)
for key in (['r', 'o_2', 'ag_2'] + list(self.buffers.keys())):
assert key in transitions, "key %s missing from transitions" % key
return transitions
def store_episode(self, episode_batch):
"""episode_batch: array(batch_size x (T or T+1) x dim_key)
"""
batch_sizes = [len(episode_batch[key]) for key in episode_batch.keys()]
assert np.all(np.array(batch_sizes) == batch_sizes[0])
batch_size = batch_sizes[0]
with self.lock:
idxs = self._get_storage_idx(batch_size)
# load inputs into buffers
for key in self.buffers.keys():
self.buffers[key][idxs] = episode_batch[key]
self.n_transitions_stored += batch_size * self.T
def get_current_episode_size(self):
with self.lock:
return self.current_size
def get_current_size(self):
with self.lock:
return self.current_size * self.T
def get_transitions_stored(self):
with self.lock:
return self.n_transitions_stored
def clear_buffer(self):
with self.lock:
self.current_size = 0
def _get_storage_idx(self, inc=None):
inc = inc or 1 # size increment
assert inc <= self.size, "Batch committed to replay is too large!"
# go consecutively until you hit the end, and then go randomly.
if self.current_size+inc <= self.size:
idx = np.arange(self.current_size, self.current_size+inc)
elif self.current_size < self.size:
overflow = inc - (self.size - self.current_size)
idx_a = np.arange(self.current_size, self.size)
idx_b = np.random.randint(0, self.current_size, overflow)
idx = np.concatenate([idx_a, idx_b])
else:
idx = np.random.randint(0, self.size, inc)
# update replay size
self.current_size = min(self.size, self.current_size+inc)
if inc == 1:
idx = idx[0]
return idx

188
baselines/her/rollout.py Normal file
View File

@@ -0,0 +1,188 @@
from collections import deque
import numpy as np
import pickle
from mujoco_py import MujocoException
from baselines.her.util import convert_episode_to_batch_major, store_args
class RolloutWorker:
@store_args
def __init__(self, make_env, policy, dims, logger, T, rollout_batch_size=1,
exploit=False, use_target_net=False, compute_Q=False, noise_eps=0,
random_eps=0, history_len=100, render=False, **kwargs):
"""Rollout worker generates experience by interacting with one or many environments.
Args:
make_env (function): a factory function that creates a new instance of the environment
when called
policy (object): the policy that is used to act
dims (dict of ints): the dimensions for observations (o), goals (g), and actions (u)
logger (object): the logger that is used by the rollout worker
rollout_batch_size (int): the number of parallel rollouts that should be used
exploit (boolean): whether or not to exploit, i.e. to act optimally according to the
current policy without any exploration
use_target_net (boolean): whether or not to use the target net for rollouts
compute_Q (boolean): whether or not to compute the Q values alongside the actions
noise_eps (float): scale of the additive Gaussian noise
random_eps (float): probability of selecting a completely random action
history_len (int): length of history for statistics smoothing
render (boolean): whether or not to render the rollouts
"""
self.envs = [make_env() for _ in range(rollout_batch_size)]
assert self.T > 0
self.info_keys = [key.replace('info_', '') for key in dims.keys() if key.startswith('info_')]
self.success_history = deque(maxlen=history_len)
self.Q_history = deque(maxlen=history_len)
self.n_episodes = 0
self.g = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # goals
self.initial_o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32) # observations
self.initial_ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # achieved goals
self.reset_all_rollouts()
self.clear_history()
def reset_rollout(self, i):
"""Resets the `i`-th rollout environment, re-samples a new goal, and updates the `initial_o`
and `g` arrays accordingly.
"""
obs = self.envs[i].reset()
self.initial_o[i] = obs['observation']
self.initial_ag[i] = obs['achieved_goal']
self.g[i] = obs['desired_goal']
def reset_all_rollouts(self):
"""Resets all `rollout_batch_size` rollout workers.
"""
for i in range(self.rollout_batch_size):
self.reset_rollout(i)
def generate_rollouts(self):
"""Performs `rollout_batch_size` rollouts in parallel for time horizon `T` with the current
policy acting on it accordingly.
"""
self.reset_all_rollouts()
# compute observations
o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32) # observations
ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # achieved goals
o[:] = self.initial_o
ag[:] = self.initial_ag
# generate episodes
obs, achieved_goals, acts, goals, successes = [], [], [], [], []
info_values = [np.empty((self.T, self.rollout_batch_size, self.dims['info_' + key]), np.float32) for key in self.info_keys]
Qs = []
for t in range(self.T):
policy_output = self.policy.get_actions(
o, ag, self.g,
compute_Q=self.compute_Q,
noise_eps=self.noise_eps if not self.exploit else 0.,
random_eps=self.random_eps if not self.exploit else 0.,
use_target_net=self.use_target_net)
if self.compute_Q:
u, Q = policy_output
Qs.append(Q)
else:
u = policy_output
if u.ndim == 1:
# The non-batched case should still have a reasonable shape.
u = u.reshape(1, -1)
o_new = np.empty((self.rollout_batch_size, self.dims['o']))
ag_new = np.empty((self.rollout_batch_size, self.dims['g']))
success = np.zeros(self.rollout_batch_size)
# compute new states and observations
for i in range(self.rollout_batch_size):
try:
# We fully ignore the reward here because it will have to be re-computed
# for HER.
curr_o_new, _, _, info = self.envs[i].step(u[i])
if 'is_success' in info:
success[i] = info['is_success']
o_new[i] = curr_o_new['observation']
ag_new[i] = curr_o_new['achieved_goal']
for idx, key in enumerate(self.info_keys):
info_values[idx][t, i] = info[key]
if self.render:
self.envs[i].render()
except MujocoException as e:
return self.generate_rollouts()
if np.isnan(o_new).any():
self.logger.warning('NaN caught during rollout generation. Trying again...')
self.reset_all_rollouts()
return self.generate_rollouts()
obs.append(o.copy())
achieved_goals.append(ag.copy())
successes.append(success.copy())
acts.append(u.copy())
goals.append(self.g.copy())
o[...] = o_new
ag[...] = ag_new
obs.append(o.copy())
achieved_goals.append(ag.copy())
self.initial_o[:] = o
episode = dict(o=obs,
u=acts,
g=goals,
ag=achieved_goals)
for key, value in zip(self.info_keys, info_values):
episode['info_{}'.format(key)] = value
# stats
successful = np.array(successes)[-1, :]
assert successful.shape == (self.rollout_batch_size,)
success_rate = np.mean(successful)
self.success_history.append(success_rate)
if self.compute_Q:
self.Q_history.append(np.mean(Qs))
self.n_episodes += self.rollout_batch_size
return convert_episode_to_batch_major(episode)
def clear_history(self):
"""Clears all histories that are used for statistics
"""
self.success_history.clear()
self.Q_history.clear()
def current_success_rate(self):
return np.mean(self.success_history)
def current_mean_Q(self):
return np.mean(self.Q_history)
def save_policy(self, path):
"""Pickles the current policy for later inspection.
"""
with open(path, 'wb') as f:
pickle.dump(self.policy, f)
def logs(self, prefix='worker'):
"""Generates a dictionary that contains all collected statistics.
"""
logs = []
logs += [('success_rate', np.mean(self.success_history))]
if self.compute_Q:
logs += [('mean_Q', np.mean(self.Q_history))]
logs += [('episode', self.n_episodes)]
if prefix is not '' and not prefix.endswith('/'):
return [(prefix + '/' + key, val) for key, val in logs]
else:
return logs
def seed(self, seed):
"""Seeds each environment with a distinct seed derived from the passed in global seed.
"""
for idx, env in enumerate(self.envs):
env.seed(seed + 1000 * idx)

144
baselines/her/util.py Normal file
View File

@@ -0,0 +1,144 @@
import os
import subprocess
import sys
import importlib
import inspect
import functools
import tensorflow as tf
import numpy as np
from baselines.common import tf_util as U
def store_args(method):
"""Stores provided method args as instance attributes.
"""
argspec = inspect.getfullargspec(method)
defaults = {}
if argspec.defaults is not None:
defaults = dict(
zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
if argspec.kwonlydefaults is not None:
defaults.update(argspec.kwonlydefaults)
arg_names = argspec.args[1:]
@functools.wraps(method)
def wrapper(*positional_args, **keyword_args):
self = positional_args[0]
# Get default arg values
args = defaults.copy()
# Add provided arg values
for name, value in zip(arg_names, positional_args[1:]):
args[name] = value
args.update(keyword_args)
self.__dict__.update(args)
return method(*positional_args, **keyword_args)
return wrapper
def import_function(spec):
"""Import a function identified by a string like "pkg.module:fn_name".
"""
mod_name, fn_name = spec.split(':')
module = importlib.import_module(mod_name)
fn = getattr(module, fn_name)
return fn
def flatten_grads(var_list, grads):
"""Flattens a variables and their gradients.
"""
return tf.concat([tf.reshape(grad, [U.numel(v)])
for (v, grad) in zip(var_list, grads)], 0)
def nn(input, layers_sizes, reuse=None, flatten=False, name=""):
"""Creates a simple neural network
"""
for i, size in enumerate(layers_sizes):
activation = tf.nn.relu if i < len(layers_sizes)-1 else None
input = tf.layers.dense(inputs=input,
units=size,
kernel_initializer=tf.contrib.layers.xavier_initializer(),
reuse=reuse,
name=name+'_'+str(i))
if activation:
input = activation(input)
if flatten:
assert layers_sizes[-1] == 1
input = tf.reshape(input, [-1])
return input
def install_mpi_excepthook():
import sys
from mpi4py import MPI
old_hook = sys.excepthook
def new_hook(a, b, c):
old_hook(a, b, c)
sys.stdout.flush()
sys.stderr.flush()
MPI.COMM_WORLD.Abort()
sys.excepthook = new_hook
def mpi_fork(n):
"""Re-launches the current script with workers
Returns "parent" for original parent, "child" for MPI children
"""
if n <= 1:
return "child"
if os.getenv("IN_MPI") is None:
env = os.environ.copy()
env.update(
MKL_NUM_THREADS="1",
OMP_NUM_THREADS="1",
IN_MPI="1"
)
# "-bind-to core" is crucial for good performance
args = [
"mpirun",
"-np",
str(n),
"-bind-to",
"core",
sys.executable
]
args += sys.argv
subprocess.check_call(args, env=env)
return "parent"
else:
install_mpi_excepthook()
return "child"
def convert_episode_to_batch_major(episode):
"""Converts an episode to have the batch dimension in the major (first)
dimension.
"""
episode_batch = {}
for key in episode.keys():
val = np.array(episode[key]).copy()
# make inputs batch-major instead of time-major
episode_batch[key] = val.swapaxes(0, 1)
return episode_batch
def transitions_in_episode_batch(episode_batch):
"""Number of transitions in a given episode batch.
"""
shape = episode_batch['u'].shape
return shape[0] * shape[1]
def reshape_for_broadcasting(source, target):
"""Reshapes a tensor (source) to have the correct shape and dtype of the target
before broadcasting it with MPI.
"""
dim = len(target.get_shape())
shape = ([1] * (dim-1)) + [-1]
return tf.reshape(tf.cast(source, target.dtype), shape)

View File

@@ -6,6 +6,7 @@ import json
import time
import datetime
import tempfile
from collections import defaultdict
LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv']
# Also valid: json, tensorboard
@@ -124,7 +125,7 @@ class CSVOutputFormat(KVWriter):
if i > 0:
self.file.write(',')
v = kvs.get(k)
if v:
if v is not None:
self.file.write(str(v))
self.file.write('\n')
self.file.flush()
@@ -168,24 +169,18 @@ class TensorBoardOutputFormat(KVWriter):
self.writer.Close()
self.writer = None
def make_output_format(format, ev_dir):
from mpi4py import MPI
def make_output_format(format, ev_dir, log_suffix=''):
os.makedirs(ev_dir, exist_ok=True)
rank = MPI.COMM_WORLD.Get_rank()
if format == 'stdout':
return HumanOutputFormat(sys.stdout)
elif format == 'log':
suffix = "" if rank==0 else ("-mpi%03i"%rank)
return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % suffix))
return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix))
elif format == 'json':
assert rank==0
return JSONOutputFormat(osp.join(ev_dir, 'progress.json'))
return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix))
elif format == 'csv':
assert rank==0
return CSVOutputFormat(osp.join(ev_dir, 'progress.csv'))
return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix))
elif format == 'tensorboard':
assert rank==0
return TensorBoardOutputFormat(osp.join(ev_dir, 'tb'))
return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix))
else:
raise ValueError('Unknown format specified: %s' % (format,))
@@ -197,9 +192,16 @@ def logkv(key, val):
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
If called many times, last value will be used.
"""
Logger.CURRENT.logkv(key, val)
def logkv_mean(key, val):
"""
The same as logkv(), but if called many times, values averaged.
"""
Logger.CURRENT.logkv_mean(key, val)
def logkvs(d):
"""
Log a dictionary of key-value pairs
@@ -255,6 +257,33 @@ def get_dir():
record_tabular = logkv
dump_tabular = dumpkvs
class ProfileKV:
"""
Usage:
with logger.ProfileKV("interesting_scope"):
code
"""
def __init__(self, n):
self.n = "wait_" + n
def __enter__(self):
self.t1 = time.time()
def __exit__(self ,type, value, traceback):
Logger.CURRENT.name2val[self.n] += time.time() - self.t1
def profile(n):
"""
Usage:
@profile("my_func")
def my_func(): code
"""
def decorator_with_name(func):
def func_wrapper(*args, **kwargs):
with ProfileKV(n):
return func(*args, **kwargs)
return func_wrapper
return decorator_with_name
# ================================================================
# Backend
# ================================================================
@@ -265,7 +294,8 @@ class Logger(object):
CURRENT = None # Current logger being used by the free functions above
def __init__(self, dir, output_formats):
self.name2val = {} # values this iteration
self.name2val = defaultdict(float) # values this iteration
self.name2cnt = defaultdict(int)
self.level = INFO
self.dir = dir
self.output_formats = output_formats
@@ -275,12 +305,21 @@ class Logger(object):
def logkv(self, key, val):
self.name2val[key] = val
def logkv_mean(self, key, val):
if val is None:
self.name2val[key] = None
return
oldval, cnt = self.name2val[key], self.name2cnt[key]
self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
self.name2cnt[key] = cnt + 1
def dumpkvs(self):
if self.level == DISABLED: return
for fmt in self.output_formats:
if isinstance(fmt, KVWriter):
fmt.writekvs(self.name2val)
self.name2val.clear()
self.name2cnt.clear()
def log(self, *args, level=INFO):
if self.level <= level:
@@ -360,6 +399,11 @@ def _demo():
logkv("a", 5.5)
dumpkvs()
info("^^^ should see a = 5.5")
logkv_mean("b", -22.5)
logkv_mean("b", -44.4)
logkv("a", 5.5)
dumpkvs()
info("^^^ should see b = 33.3")
logkv("b", -2.5)
dumpkvs()

View File

@@ -10,7 +10,7 @@ setup(name='baselines',
packages=[package for package in find_packages()
if package.startswith('baselines')],
install_requires=[
'gym[mujoco,atari,classic_control]',
'gym[mujoco,atari,classic_control,robotics]',
'scipy',
'tqdm',
'joblib',
@@ -19,9 +19,11 @@ setup(name='baselines',
'progressbar2',
'mpi4py',
'cloudpickle',
'tensorflow>=1.4.0',
'click',
],
description='OpenAI baselines: high quality implementations of reinforcement learning algorithms',
author='OpenAI',
url='https://github.com/openai/baselines',
author_email='gym@openai.com',
version='0.1.4')
version='0.1.5')