Compare commits: master...matthias-h (2 commits)

Commits: d90638b565, f4953c3c2d
@@ -20,6 +20,7 @@ pip install -e .
|
||||
- [DDPG](baselines/ddpg)
|
||||
- [DQN](baselines/deepq)
|
||||
- [GAIL](baselines/gail)
|
||||
- [HER](baselines/her)
|
||||
- [PPO1](baselines/ppo1) (Multi-CPU using MPI)
|
||||
- [PPO2](baselines/ppo2) (Optimized for GPU)
|
||||
- [TRPO](baselines/trpo_mpi)
|
||||
|
@@ -39,12 +39,24 @@ def ortho_init(scale=1.0):
|
||||
return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
|
||||
return _ortho_init
|
||||
|
||||
def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0):
|
||||
def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC'):
|
||||
if data_format == 'NHWC':
|
||||
channel_ax = 3
|
||||
strides = [1, stride, stride, 1]
|
||||
bshape = [1, 1, 1, nf]
|
||||
elif data_format == 'NCHW':
|
||||
channel_ax = 1
|
||||
strides = [1, 1, stride, stride]
|
||||
bshape = [1, nf, 1, 1]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
nin = x.get_shape()[channel_ax].value
|
||||
wshape = [rf, rf, nin, nf]
|
||||
with tf.variable_scope(scope):
|
||||
nin = x.get_shape()[3].value
|
||||
w = tf.get_variable("w", [rf, rf, nin, nf], initializer=ortho_init(init_scale))
|
||||
b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0.0))
|
||||
return tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=pad)+b
|
||||
w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale))
|
||||
b = tf.get_variable("b", [1, nf, 1, 1], initializer=tf.constant_initializer(0.0))
|
||||
if data_format == 'NHWC': b = tf.reshape(b, bshape)
|
||||
return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format)
|
||||
|
||||
def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
|
||||
with tf.variable_scope(scope):
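For reference, a minimal usage sketch of the updated `conv` signature (the module path `baselines.a2c.utils` and the placeholder shapes are assumptions for illustration; NCHW convolutions generally require a GPU kernel):

```python
import tensorflow as tf
from baselines.a2c.utils import conv  # assumed module path for this hunk

# channels-last input (default): strides and bias are laid out as NHWC
x_nhwc = tf.placeholder(tf.float32, [None, 84, 84, 4])
h_nhwc = conv(x_nhwc, 'c1', nf=32, rf=8, stride=4)

# channels-first input: pass data_format='NCHW' so strides and bias follow axis 1
x_nchw = tf.placeholder(tf.float32, [None, 4, 84, 84])
h_nchw = conv(x_nchw, 'c1_nchw', nf=32, rf=8, stride=4, data_format='NCHW')
```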
|
||||
|
@@ -7,12 +7,13 @@ from glob import glob
|
||||
import csv
|
||||
import os.path as osp
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
class Monitor(Wrapper):
|
||||
EXT = "monitor.csv"
|
||||
f = None
|
||||
|
||||
def __init__(self, env, filename, allow_early_resets=False, reset_keywords=()):
|
||||
def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
|
||||
Wrapper.__init__(self, env=env)
|
||||
self.tstart = time.time()
|
||||
if filename is None:
|
||||
@@ -26,10 +27,12 @@ class Monitor(Wrapper):
|
||||
filename = filename + "." + Monitor.EXT
|
||||
self.f = open(filename, "wt")
|
||||
self.f.write('#%s\n'%json.dumps({"t_start": self.tstart, 'env_id' : env.spec and env.spec.id}))
|
||||
self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords)
|
||||
self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords+info_keywords)
|
||||
self.logger.writeheader()
|
||||
self.f.flush()
|
||||
|
||||
self.reset_keywords = reset_keywords
|
||||
self.info_keywords = info_keywords
|
||||
self.allow_early_resets = allow_early_resets
|
||||
self.rewards = None
|
||||
self.needs_reset = True
|
||||
@@ -61,6 +64,8 @@ class Monitor(Wrapper):
|
||||
eprew = sum(self.rewards)
|
||||
eplen = len(self.rewards)
|
||||
epinfo = {"r": round(eprew, 6), "l": eplen, "t": round(time.time() - self.tstart, 6)}
|
||||
for k in self.info_keywords:
|
||||
epinfo[k] = info[k]
|
||||
self.episode_rewards.append(eprew)
|
||||
self.episode_lengths.append(eplen)
|
||||
self.episode_times.append(time.time() - self.tstart)
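A minimal sketch of the new `info_keywords` argument, mirroring how `make_robotics_env` below uses it (the env id and log path are illustrative; the wrapped env must report the keyword in its step `info` dict):

```python
import os
import gym
from gym.wrappers import FlattenDictWrapper
from baselines.bench import Monitor

# FetchReach exposes an `is_success` flag in its info dict
env = gym.make('FetchReach-v0')
env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
env = Monitor(env, os.path.join('/tmp/her_logs', '0'), info_keywords=('is_success',))
# each episode row in monitor.csv now has the columns r, l, t, is_success
```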
|
||||
|
@@ -4,6 +4,7 @@ Helpers for scripts like run_atari.py.
|
||||
|
||||
import os
|
||||
import gym
|
||||
from gym.wrappers import FlattenDictWrapper
|
||||
from baselines import logger
|
||||
from baselines.bench import Monitor
|
||||
from baselines.common import set_global_seeds
|
||||
@@ -36,6 +37,19 @@ def make_mujoco_env(env_id, seed):
|
||||
env.seed(seed)
|
||||
return env
|
||||
|
||||
def make_robotics_env(env_id, seed, rank=0):
|
||||
"""
|
||||
Create a wrapped, monitored gym.Env for the MuJoCo-based robotics environments.
|
||||
"""
|
||||
set_global_seeds(seed)
|
||||
env = gym.make(env_id)
|
||||
env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
|
||||
env = Monitor(
|
||||
env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
|
||||
info_keywords=('is_success',))
|
||||
env.seed(seed)
|
||||
return env
|
||||
|
||||
def arg_parser():
|
||||
"""
|
||||
Create an empty argparse.ArgumentParser.
|
||||
@@ -58,7 +72,17 @@ def mujoco_arg_parser():
|
||||
Create an argparse.ArgumentParser for run_mujoco.py.
|
||||
"""
|
||||
parser = arg_parser()
|
||||
parser.add_argument('--env', help='environment ID', type=str, default="Reacher-v1")
|
||||
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
|
||||
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
|
||||
parser.add_argument('--num-timesteps', type=int, default=int(1e6))
|
||||
return parser
|
||||
|
||||
def robotics_arg_parser():
|
||||
"""
|
||||
Create an argparse.ArgumentParser for the robotics environments.
|
||||
"""
|
||||
parser = arg_parser()
|
||||
parser.add_argument('--env', help='environment ID', type=str, default='FetchReach-v0')
|
||||
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
|
||||
parser.add_argument('--num-timesteps', type=int, default=int(1e6))
|
||||
return parser
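A hedged sketch of wiring the new helpers together in a training script (the script structure is an assumption, not part of this diff):

```python
from baselines.common.cmd_util import make_robotics_env, robotics_arg_parser

args = robotics_arg_parser().parse_args()
env = make_robotics_env(args.env, args.seed)
```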
|
||||
|
@@ -16,7 +16,12 @@ def fmt_item(x, l):
|
||||
if isinstance(x, np.ndarray):
|
||||
assert x.ndim==0
|
||||
x = x.item()
|
||||
if isinstance(x, float): rep = "%g"%x
|
||||
if isinstance(x, (float, np.float32, np.float64)):
|
||||
v = abs(x)
|
||||
if (v < 1e-4 or v > 1e+4) and v > 0:
|
||||
rep = "%7.2e" % x
|
||||
else:
|
||||
rep = "%7.5f" % x
|
||||
else: rep = str(x)
|
||||
return " "*(l - len(rep)) + rep
|
||||
|
||||
|
@@ -261,3 +261,20 @@ def get_placeholder_cached(name):
|
||||
|
||||
def flattenallbut0(x):
|
||||
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Diagnostics
|
||||
# ================================================================
|
||||
|
||||
def display_var_info(vars):
|
||||
from baselines import logger
|
||||
count_params = 0
|
||||
for v in vars:
|
||||
name = v.name
|
||||
if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue
|
||||
count_params += np.prod(v.shape.as_list())
|
||||
if "/b:" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print
|
||||
logger.info(" %s%s%s" % (name, " "*(55-len(name)), str(v.shape)))
|
||||
logger.info("Total model parameters: %0.1f million" % (count_params*1e-6))
|
||||
|
||||
|
@@ -1,31 +1,42 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from . import VecEnv
|
||||
|
||||
class DummyVecEnv(VecEnv):
|
||||
def __init__(self, env_fns):
|
||||
self.envs = [fn() for fn in env_fns]
|
||||
env = self.envs[0]
|
||||
env = self.envs[0]
|
||||
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
|
||||
self.ts = np.zeros(len(self.envs), dtype='int')
|
||||
|
||||
obs_spaces = self.observation_space.spaces if isinstance(self.observation_space, gym.spaces.Tuple) else (self.observation_space,)
|
||||
self.buf_obs = [np.zeros((self.num_envs,) + tuple(s.shape), s.dtype) for s in obs_spaces]
|
||||
self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
|
||||
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
||||
self.buf_infos = [{} for _ in range(self.num_envs)]
|
||||
self.actions = None
|
||||
|
||||
def step_async(self, actions):
|
||||
self.actions = actions
|
||||
|
||||
def step_wait(self):
|
||||
results = [env.step(a) for (a,env) in zip(self.actions, self.envs)]
|
||||
obs, rews, dones, infos = map(np.array, zip(*results))
|
||||
self.ts += 1
|
||||
for (i, done) in enumerate(dones):
|
||||
if done:
|
||||
obs[i] = self.envs[i].reset()
|
||||
self.ts[i] = 0
|
||||
self.actions = None
|
||||
return np.array(obs), np.array(rews), np.array(dones), infos
|
||||
for i in range(self.num_envs):
|
||||
obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i])
|
||||
if isinstance(obs_tuple, (tuple, list)):
|
||||
for t,x in enumerate(obs_tuple):
|
||||
self.buf_obs[t][i] = x
|
||||
else:
|
||||
self.buf_obs[0][i] = obs_tuple
|
||||
return self.buf_obs, self.buf_rews, self.buf_dones, self.buf_infos
|
||||
|
||||
def reset(self):
|
||||
results = [env.reset() for env in self.envs]
|
||||
return np.array(results)
|
||||
for i in range(self.num_envs):
|
||||
obs_tuple = self.envs[i].reset()
|
||||
if isinstance(obs_tuple, (tuple, list)):
|
||||
for t,x in enumerate(obs_tuple):
|
||||
self.buf_obs[t][i] = x
|
||||
else:
|
||||
self.buf_obs[0][i] = obs_tuple
|
||||
return self.buf_obs
|
||||
|
||||
def close(self):
|
||||
return
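A hedged sketch of the rewritten, buffer-based `DummyVecEnv` (the env id is illustrative; note that, in the code shown here, `step_wait` no longer resets an environment automatically when it reports `done`):

```python
import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

venv = DummyVecEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
obs = venv.reset()                             # list with one buffer per observation space
actions = [venv.action_space.sample() for _ in range(4)]
venv.step_async(actions)
obs, rews, dones, infos = venv.step_wait()     # rews and dones are length-4 arrays
```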
|
baselines/her/README.md (new file, 35 lines)
@@ -0,0 +1,35 @@
# Hindsight Experience Replay
For details on Hindsight Experience Replay (HER), please read the [paper](https://arxiv.org/pdf/1707.01495.pdf).

## How to use Hindsight Experience Replay

### Getting started
Training an agent is very simple:
```bash
python -m baselines.her.experiment.train
```
This will train a DDPG+HER agent on the `FetchReach` environment.
You should see the success rate go up quickly to `1.0`, which means that the agent achieves the
desired goal in 100% of the cases.
The training script logs other diagnostics as well and pickles the best policy so far (w.r.t. its test success rate),
the latest policy, and, if enabled, a history of policies every K epochs.

To inspect what the agent has learned, use the play script:
```bash
python -m baselines.her.experiment.play /path/to/an/experiment/policy_best.pkl
```
You can try it right now with the results of the training step (the script prints out the path for you).
This should visualize the current policy for 10 episodes and will also print statistics.


### Advanced usage
The train script comes with advanced features like MPI support, which allows training to scale across all cores of a single machine.
To see all available options, simply run this command:
```bash
python -m baselines.her.experiment.train --help
```
To run on, say, 20 CPU cores, you can use the following command:
```bash
python -m baselines.her.experiment.train --num_cpu 20
```
That's it: you are now running rollouts with 20 MPI workers and averaging gradients for the network updates across all 20 cores.
baselines/her/__init__.py (new file, empty)
baselines/her/actor_critic.py (new file, 44 lines)
@@ -0,0 +1,44 @@
|
||||
import tensorflow as tf
|
||||
from baselines.her.util import store_args, nn
|
||||
|
||||
|
||||
class ActorCritic:
|
||||
@store_args
|
||||
def __init__(self, inputs_tf, dimo, dimg, dimu, max_u, o_stats, g_stats, hidden, layers,
|
||||
**kwargs):
|
||||
"""The actor-critic network and related training code.
|
||||
|
||||
Args:
|
||||
inputs_tf (dict of tensors): all necessary inputs for the network: the
|
||||
observation (o), the goal (g), and the action (u)
|
||||
dimo (int): the dimension of the observations
|
||||
dimg (int): the dimension of the goals
|
||||
dimu (int): the dimension of the actions
|
||||
max_u (float): the maximum magnitude of actions; action outputs will be scaled
|
||||
accordingly
|
||||
o_stats (baselines.her.Normalizer): normalizer for observations
|
||||
g_stats (baselines.her.Normalizer): normalizer for goals
|
||||
hidden (int): number of hidden units that should be used in hidden layers
|
||||
layers (int): number of hidden layers
|
||||
"""
|
||||
self.o_tf = inputs_tf['o']
|
||||
self.g_tf = inputs_tf['g']
|
||||
self.u_tf = inputs_tf['u']
|
||||
|
||||
# Prepare inputs for actor and critic.
|
||||
o = self.o_stats.normalize(self.o_tf)
|
||||
g = self.g_stats.normalize(self.g_tf)
|
||||
input_pi = tf.concat(axis=1, values=[o, g]) # for actor
|
||||
|
||||
# Networks.
|
||||
with tf.variable_scope('pi'):
|
||||
self.pi_tf = self.max_u * tf.tanh(nn(
|
||||
input_pi, [self.hidden] * self.layers + [self.dimu]))
|
||||
with tf.variable_scope('Q'):
|
||||
# for policy training
|
||||
input_Q = tf.concat(axis=1, values=[o, g, self.pi_tf / self.max_u])
|
||||
self.Q_pi_tf = nn(input_Q, [self.hidden] * self.layers + [1])
|
||||
# for critic training
|
||||
input_Q = tf.concat(axis=1, values=[o, g, self.u_tf / self.max_u])
|
||||
self._input_Q = input_Q # exposed for tests
|
||||
self.Q_tf = nn(input_Q, [self.hidden] * self.layers + [1], reuse=True)
|
baselines/her/ddpg.py (new file, 340 lines)
@@ -0,0 +1,340 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.staging import StagingArea
|
||||
|
||||
from baselines import logger
|
||||
from baselines.her.util import (
|
||||
import_function, store_args, flatten_grads, transitions_in_episode_batch)
|
||||
from baselines.her.normalizer import Normalizer
|
||||
from baselines.her.replay_buffer import ReplayBuffer
|
||||
from baselines.common.mpi_adam import MpiAdam
|
||||
|
||||
|
||||
def dims_to_shapes(input_dims):
|
||||
return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()}
|
||||
|
||||
|
||||
class DDPG(object):
|
||||
@store_args
|
||||
def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
|
||||
Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
|
||||
rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
|
||||
sample_transitions, gamma, reuse=False, **kwargs):
|
||||
"""Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
|
||||
|
||||
Args:
|
||||
input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
|
||||
actions (u)
|
||||
buffer_size (int): number of transitions that are stored in the replay buffer
|
||||
hidden (int): number of units in the hidden layers
|
||||
layers (int): number of hidden layers
|
||||
network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
|
||||
polyak (float): coefficient for Polyak-averaging of the target network
|
||||
batch_size (int): batch size for training
|
||||
Q_lr (float): learning rate for the Q (critic) network
|
||||
pi_lr (float): learning rate for the pi (actor) network
|
||||
norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
|
||||
norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
|
||||
max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
|
||||
action_l2 (float): coefficient for L2 penalty on the actions
|
||||
clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
|
||||
scope (str): the scope used for the TensorFlow graph
|
||||
T (int): the time horizon for rollouts
|
||||
rollout_batch_size (int): number of parallel rollouts per DDPG agent
|
||||
subtract_goals (function): function that subtracts goals from each other
|
||||
relative_goals (boolean): whether or not relative goals should be fed into the network
|
||||
clip_pos_returns (boolean): whether or not positive returns should be clipped
|
||||
clip_return (float): clip returns to be in [-clip_return, clip_return]
|
||||
sample_transitions (function): function that samples from the replay buffer
|
||||
gamma (float): gamma used for Q learning updates
|
||||
reuse (boolean): whether or not the networks should be reused
|
||||
"""
|
||||
if self.clip_return is None:
|
||||
self.clip_return = np.inf
|
||||
|
||||
self.create_actor_critic = import_function(self.network_class)
|
||||
|
||||
input_shapes = dims_to_shapes(self.input_dims)
|
||||
self.dimo = self.input_dims['o']
|
||||
self.dimg = self.input_dims['g']
|
||||
self.dimu = self.input_dims['u']
|
||||
|
||||
# Prepare staging area for feeding data to the model.
|
||||
stage_shapes = OrderedDict()
|
||||
for key in sorted(self.input_dims.keys()):
|
||||
if key.startswith('info_'):
|
||||
continue
|
||||
stage_shapes[key] = (None, *input_shapes[key])
|
||||
for key in ['o', 'g']:
|
||||
stage_shapes[key + '_2'] = stage_shapes[key]
|
||||
stage_shapes['r'] = (None,)
|
||||
self.stage_shapes = stage_shapes
|
||||
|
||||
# Create network.
|
||||
with tf.variable_scope(self.scope):
|
||||
self.staging_tf = StagingArea(
|
||||
dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
|
||||
shapes=list(self.stage_shapes.values()))
|
||||
self.buffer_ph_tf = [
|
||||
tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()]
|
||||
self.stage_op = self.staging_tf.put(self.buffer_ph_tf)
|
||||
|
||||
self._create_network(reuse=reuse)
|
||||
|
||||
# Configure the replay buffer.
|
||||
buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *input_shapes[key])
|
||||
for key, val in input_shapes.items()}
|
||||
buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
|
||||
buffer_shapes['ag'] = (self.T+1, self.dimg)
|
||||
|
||||
buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
|
||||
self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)
|
||||
|
||||
def _random_action(self, n):
|
||||
return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))
|
||||
|
||||
def _preprocess_og(self, o, ag, g):
|
||||
if self.relative_goals:
|
||||
g_shape = g.shape
|
||||
g = g.reshape(-1, self.dimg)
|
||||
ag = ag.reshape(-1, self.dimg)
|
||||
g = self.subtract_goals(g, ag)
|
||||
g = g.reshape(*g_shape)
|
||||
o = np.clip(o, -self.clip_obs, self.clip_obs)
|
||||
g = np.clip(g, -self.clip_obs, self.clip_obs)
|
||||
return o, g
|
||||
|
||||
def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
|
||||
compute_Q=False):
|
||||
o, g = self._preprocess_og(o, ag, g)
|
||||
policy = self.target if use_target_net else self.main
|
||||
# values to compute
|
||||
vals = [policy.pi_tf]
|
||||
if compute_Q:
|
||||
vals += [policy.Q_pi_tf]
|
||||
# feed
|
||||
feed = {
|
||||
policy.o_tf: o.reshape(-1, self.dimo),
|
||||
policy.g_tf: g.reshape(-1, self.dimg),
|
||||
policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
|
||||
}
|
||||
|
||||
ret = self.sess.run(vals, feed_dict=feed)
|
||||
# action postprocessing
|
||||
u = ret[0]
|
||||
noise = noise_eps * self.max_u * np.random.randn(*u.shape) # gaussian noise
|
||||
u += noise
|
||||
u = np.clip(u, -self.max_u, self.max_u)
|
||||
u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u) # eps-greedy
|
||||
if u.shape[0] == 1:
|
||||
u = u[0]
|
||||
u = u.copy()
|
||||
ret[0] = u
|
||||
|
||||
if len(ret) == 1:
|
||||
return ret[0]
|
||||
else:
|
||||
return ret
|
||||
|
||||
def store_episode(self, episode_batch, update_stats=True):
|
||||
"""
|
||||
episode_batch: array of batch_size x (T or T+1) x dim_key
|
||||
'o' is of size T+1, others are of size T
|
||||
"""
|
||||
|
||||
self.buffer.store_episode(episode_batch)
|
||||
|
||||
if update_stats:
|
||||
# add transitions to normalizer
|
||||
episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
|
||||
episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
|
||||
num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
|
||||
transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)
|
||||
|
||||
o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
|
||||
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
|
||||
# No need to preprocess the o_2 and g_2 since this is only used for stats
|
||||
|
||||
self.o_stats.update(transitions['o'])
|
||||
self.g_stats.update(transitions['g'])
|
||||
|
||||
self.o_stats.recompute_stats()
|
||||
self.g_stats.recompute_stats()
|
||||
|
||||
def get_current_buffer_size(self):
|
||||
return self.buffer.get_current_size()
|
||||
|
||||
def _sync_optimizers(self):
|
||||
self.Q_adam.sync()
|
||||
self.pi_adam.sync()
|
||||
|
||||
def _grads(self):
|
||||
# Avoid feed_dict here for performance!
|
||||
critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
|
||||
self.Q_loss_tf,
|
||||
self.main.Q_pi_tf,
|
||||
self.Q_grad_tf,
|
||||
self.pi_grad_tf
|
||||
])
|
||||
return critic_loss, actor_loss, Q_grad, pi_grad
|
||||
|
||||
def _update(self, Q_grad, pi_grad):
|
||||
self.Q_adam.update(Q_grad, self.Q_lr)
|
||||
self.pi_adam.update(pi_grad, self.pi_lr)
|
||||
|
||||
def sample_batch(self):
|
||||
transitions = self.buffer.sample(self.batch_size)
|
||||
o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
|
||||
ag, ag_2 = transitions['ag'], transitions['ag_2']
|
||||
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
|
||||
transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g)
|
||||
|
||||
transitions_batch = [transitions[key] for key in self.stage_shapes.keys()]
|
||||
return transitions_batch
|
||||
|
||||
def stage_batch(self, batch=None):
|
||||
if batch is None:
|
||||
batch = self.sample_batch()
|
||||
assert len(self.buffer_ph_tf) == len(batch)
|
||||
self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))
|
||||
|
||||
def train(self, stage=True):
|
||||
if stage:
|
||||
self.stage_batch()
|
||||
critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
|
||||
self._update(Q_grad, pi_grad)
|
||||
return critic_loss, actor_loss
|
||||
|
||||
def _init_target_net(self):
|
||||
self.sess.run(self.init_target_net_op)
|
||||
|
||||
def update_target_net(self):
|
||||
self.sess.run(self.update_target_net_op)
|
||||
|
||||
def clear_buffer(self):
|
||||
self.buffer.clear_buffer()
|
||||
|
||||
def _vars(self, scope):
|
||||
res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope)
|
||||
assert len(res) > 0
|
||||
return res
|
||||
|
||||
def _global_vars(self, scope):
|
||||
res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope)
|
||||
return res
|
||||
|
||||
def _create_network(self, reuse=False):
|
||||
logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))
|
||||
|
||||
self.sess = tf.get_default_session()
|
||||
if self.sess is None:
|
||||
self.sess = tf.InteractiveSession()
|
||||
|
||||
# running averages
|
||||
with tf.variable_scope('o_stats') as vs:
|
||||
if reuse:
|
||||
vs.reuse_variables()
|
||||
self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess)
|
||||
with tf.variable_scope('g_stats') as vs:
|
||||
if reuse:
|
||||
vs.reuse_variables()
|
||||
self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess)
|
||||
|
||||
# mini-batch sampling.
|
||||
batch = self.staging_tf.get()
|
||||
batch_tf = OrderedDict([(key, batch[i])
|
||||
for i, key in enumerate(self.stage_shapes.keys())])
|
||||
batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])
|
||||
|
||||
# networks
|
||||
with tf.variable_scope('main') as vs:
|
||||
if reuse:
|
||||
vs.reuse_variables()
|
||||
self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__)
|
||||
vs.reuse_variables()
|
||||
with tf.variable_scope('target') as vs:
|
||||
if reuse:
|
||||
vs.reuse_variables()
|
||||
target_batch_tf = batch_tf.copy()
|
||||
target_batch_tf['o'] = batch_tf['o_2']
|
||||
target_batch_tf['g'] = batch_tf['g_2']
|
||||
self.target = self.create_actor_critic(
|
||||
target_batch_tf, net_type='target', **self.__dict__)
|
||||
vs.reuse_variables()
|
||||
assert len(self._vars("main")) == len(self._vars("target"))
|
||||
|
||||
# loss functions
|
||||
target_Q_pi_tf = self.target.Q_pi_tf
|
||||
clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
|
||||
target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
|
||||
self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
|
||||
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
|
||||
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
|
||||
Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
|
||||
pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
|
||||
assert len(self._vars('main/Q')) == len(Q_grads_tf)
|
||||
assert len(self._vars('main/pi')) == len(pi_grads_tf)
|
||||
self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
|
||||
self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
|
||||
self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
|
||||
self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))
|
||||
|
||||
# optimizers
|
||||
self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
|
||||
self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)
|
||||
|
||||
# polyak averaging
|
||||
self.main_vars = self._vars('main/Q') + self._vars('main/pi')
|
||||
self.target_vars = self._vars('target/Q') + self._vars('target/pi')
|
||||
self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
|
||||
self.init_target_net_op = list(
|
||||
map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
|
||||
self.update_target_net_op = list(
|
||||
map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))
|
||||
|
||||
# initialize all variables
|
||||
tf.variables_initializer(self._global_vars('')).run()
|
||||
self._sync_optimizers()
|
||||
self._init_target_net()
|
||||
|
||||
def logs(self, prefix=''):
|
||||
logs = []
|
||||
logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
|
||||
logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
|
||||
logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
|
||||
logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
|
||||
|
||||
if prefix is not '' and not prefix.endswith('/'):
|
||||
return [(prefix + '/' + key, val) for key, val in logs]
|
||||
else:
|
||||
return logs
|
||||
|
||||
def __getstate__(self):
|
||||
"""Our policies can be loaded from pkl, but after unpickling you cannot continue training.
|
||||
"""
|
||||
excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats',
|
||||
'main', 'target', 'lock', 'env', 'sample_transitions',
|
||||
'stage_shapes', 'create_actor_critic']
|
||||
|
||||
state = {k: v for k, v in self.__dict__.items() if all([not subname in k for subname in excluded_subnames])}
|
||||
state['buffer_size'] = self.buffer_size
|
||||
state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name])
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
if 'sample_transitions' not in state:
|
||||
# We don't need this for playing the policy.
|
||||
state['sample_transitions'] = None
|
||||
|
||||
self.__init__(**state)
|
||||
# set up stats (they are overwritten in __init__)
|
||||
for k, v in state.items():
|
||||
if k[-6:] == '_stats':
|
||||
self.__dict__[k] = v
|
||||
# load TF variables
|
||||
vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
|
||||
assert(len(vars) == len(state["tf"]))
|
||||
node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
|
||||
self.sess.run(node)
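A hedged sketch of how one training cycle drives this class, mirroring `experiment/train.py` further below (`policy`, `rollout_worker`, and `n_batches` are assumed to come from the HER experiment configuration, so this is not runnable on its own):

```python
# one cycle of the DDPG+HER training loop
episode = rollout_worker.generate_rollouts()     # collect rollouts with the current policy
policy.store_episode(episode)                    # push into the replay buffer, update normalizers
for _ in range(n_batches):
    critic_loss, actor_loss = policy.train()     # stage a batch, apply MPI-Adam updates
policy.update_target_net()                       # Polyak-average main -> target networks
```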
|
baselines/her/experiment/__init__.py (new file, empty)
baselines/her/experiment/config.py (new file, 170 lines)
@@ -0,0 +1,170 @@
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import gym
|
||||
|
||||
from baselines import logger
|
||||
from baselines.her.ddpg import DDPG
|
||||
from baselines.her.her import make_sample_her_transitions
|
||||
|
||||
|
||||
DEFAULT_ENV_PARAMS = {
|
||||
'FetchReach-v0': {
|
||||
'n_cycles': 10,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_PARAMS = {
|
||||
# env
|
||||
'max_u': 1., # max absolute value of actions on different coordinates
|
||||
# ddpg
|
||||
'layers': 3, # number of layers in the critic/actor networks
|
||||
'hidden': 256, # number of neurons in each hidden layers
|
||||
'network_class': 'baselines.her.actor_critic:ActorCritic',
|
||||
'Q_lr': 0.001, # critic learning rate
|
||||
'pi_lr': 0.001, # actor learning rate
|
||||
'buffer_size': int(1E6), # for experience replay
|
||||
'polyak': 0.95, # polyak averaging coefficient
|
||||
'action_l2': 1.0, # quadratic penalty on actions (before rescaling by max_u)
|
||||
'clip_obs': 200.,
|
||||
'scope': 'ddpg', # can be tweaked for testing
|
||||
'relative_goals': False,
|
||||
# training
|
||||
'n_cycles': 50, # per epoch
|
||||
'rollout_batch_size': 2, # per mpi thread
|
||||
'n_batches': 40, # training batches per cycle
|
||||
'batch_size': 256, # per mpi thread, measured in transitions and reduced to even multiple of chunk_length.
|
||||
'n_test_rollouts': 10, # number of test rollouts per epoch, each consists of rollout_batch_size rollouts
|
||||
'test_with_polyak': False, # run test episodes with the target network
|
||||
# exploration
|
||||
'random_eps': 0.3, # percentage of time a random action is taken
|
||||
'noise_eps': 0.2, # std of gaussian noise added to not-completely-random actions as a percentage of max_u
|
||||
# HER
|
||||
'replay_strategy': 'future', # supported modes: future, none
|
||||
'replay_k': 4, # number of additional goals used for replay, only used if replay_strategy=future
|
||||
# normalization
|
||||
'norm_eps': 0.01, # epsilon used for observation normalization
|
||||
'norm_clip': 5, # normalized observations are clipped to this value
|
||||
}
|
||||
|
||||
|
||||
CACHED_ENVS = {}
|
||||
def cached_make_env(make_env):
|
||||
"""
|
||||
Only creates a new environment from the provided function if one has not yet been
|
||||
created. This is useful here because we need to infer certain properties of the env, e.g.
|
||||
its observation and action spaces, without any intent of actually using it.
|
||||
"""
|
||||
if make_env not in CACHED_ENVS:
|
||||
env = make_env()
|
||||
CACHED_ENVS[make_env] = env
|
||||
return CACHED_ENVS[make_env]
|
||||
|
||||
|
||||
def prepare_params(kwargs):
|
||||
# DDPG params
|
||||
ddpg_params = dict()
|
||||
|
||||
env_name = kwargs['env_name']
|
||||
def make_env():
|
||||
return gym.make(env_name)
|
||||
kwargs['make_env'] = make_env
|
||||
tmp_env = cached_make_env(kwargs['make_env'])
|
||||
assert hasattr(tmp_env, '_max_episode_steps')
|
||||
kwargs['T'] = tmp_env._max_episode_steps
|
||||
tmp_env.reset()
|
||||
kwargs['max_u'] = np.array(kwargs['max_u']) if type(kwargs['max_u']) == list else kwargs['max_u']
|
||||
kwargs['gamma'] = 1. - 1. / kwargs['T']
|
||||
if 'lr' in kwargs:
|
||||
kwargs['pi_lr'] = kwargs['lr']
|
||||
kwargs['Q_lr'] = kwargs['lr']
|
||||
del kwargs['lr']
|
||||
for name in ['buffer_size', 'hidden', 'layers',
|
||||
'network_class',
|
||||
'polyak',
|
||||
'batch_size', 'Q_lr', 'pi_lr',
|
||||
'norm_eps', 'norm_clip', 'max_u',
|
||||
'action_l2', 'clip_obs', 'scope', 'relative_goals']:
|
||||
ddpg_params[name] = kwargs[name]
|
||||
kwargs['_' + name] = kwargs[name]
|
||||
del kwargs[name]
|
||||
kwargs['ddpg_params'] = ddpg_params
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
def log_params(params, logger=logger):
|
||||
for key in sorted(params.keys()):
|
||||
logger.info('{}: {}'.format(key, params[key]))
|
||||
|
||||
|
||||
def configure_her(params):
|
||||
env = cached_make_env(params['make_env'])
|
||||
env.reset()
|
||||
def reward_fun(ag_2, g, info): # vectorized
|
||||
return env.compute_reward(achieved_goal=ag_2, desired_goal=g, info=info)
|
||||
|
||||
# Prepare configuration for HER.
|
||||
her_params = {
|
||||
'reward_fun': reward_fun,
|
||||
}
|
||||
for name in ['replay_strategy', 'replay_k']:
|
||||
her_params[name] = params[name]
|
||||
params['_' + name] = her_params[name]
|
||||
del params[name]
|
||||
sample_her_transitions = make_sample_her_transitions(**her_params)
|
||||
|
||||
return sample_her_transitions
|
||||
|
||||
|
||||
def simple_goal_subtract(a, b):
|
||||
assert a.shape == b.shape
|
||||
return a - b
|
||||
|
||||
|
||||
def configure_ddpg(dims, params, reuse=False, use_mpi=True, clip_return=True):
|
||||
sample_her_transitions = configure_her(params)
|
||||
# Extract relevant parameters.
|
||||
gamma = params['gamma']
|
||||
rollout_batch_size = params['rollout_batch_size']
|
||||
ddpg_params = params['ddpg_params']
|
||||
|
||||
input_dims = dims.copy()
|
||||
|
||||
# DDPG agent
|
||||
env = cached_make_env(params['make_env'])
|
||||
env.reset()
|
||||
ddpg_params.update({'input_dims': input_dims, # agent takes an input observations
|
||||
'T': params['T'],
|
||||
'clip_pos_returns': True, # clip positive returns
|
||||
'clip_return': (1. / (1. - gamma)) if clip_return else np.inf, # max abs of return
|
||||
'rollout_batch_size': rollout_batch_size,
|
||||
'subtract_goals': simple_goal_subtract,
|
||||
'sample_transitions': sample_her_transitions,
|
||||
'gamma': gamma,
|
||||
})
|
||||
ddpg_params['info'] = {
|
||||
'env_name': params['env_name'],
|
||||
}
|
||||
policy = DDPG(reuse=reuse, **ddpg_params, use_mpi=use_mpi)
|
||||
return policy
|
||||
|
||||
|
||||
def configure_dims(params):
|
||||
env = cached_make_env(params['make_env'])
|
||||
env.reset()
|
||||
obs, _, _, info = env.step(env.action_space.sample())
|
||||
|
||||
dims = {
|
||||
'o': obs['observation'].shape[0],
|
||||
'u': env.action_space.shape[0],
|
||||
'g': obs['desired_goal'].shape[0],
|
||||
}
|
||||
for key, value in info.items():
|
||||
value = np.array(value)
|
||||
if value.ndim == 0:
|
||||
value = value.reshape(1)
|
||||
dims['info_{}'.format(key)] = value.shape[0]
|
||||
return dims
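A hedged sketch of the configuration pipeline that `experiment/train.py` below builds on (the env id is illustrative; requires the gym robotics environments and mpi4py):

```python
import baselines.common.tf_util as U
import baselines.her.experiment.config as config

U.single_threaded_session().__enter__()

params = dict(config.DEFAULT_PARAMS, env_name='FetchReach-v0', replay_strategy='future')
params = config.prepare_params(params)   # splits out ddpg_params, infers T and gamma
dims = config.configure_dims(params)     # o/g/u dimensions probed from the env
policy = config.configure_ddpg(dims=dims, params=params, clip_return=True)
```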
|
baselines/her/experiment/play.py (new file, 60 lines)
@@ -0,0 +1,60 @@
|
||||
import click
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
from baselines import logger
|
||||
from baselines.common import set_global_seeds
|
||||
import baselines.her.experiment.config as config
|
||||
from baselines.her.rollout import RolloutWorker
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument('policy_file', type=str)
|
||||
@click.option('--seed', type=int, default=0)
|
||||
@click.option('--n_test_rollouts', type=int, default=10)
|
||||
@click.option('--render', type=int, default=1)
|
||||
def main(policy_file, seed, n_test_rollouts, render):
|
||||
set_global_seeds(seed)
|
||||
|
||||
# Load policy.
|
||||
with open(policy_file, 'rb') as f:
|
||||
policy = pickle.load(f)
|
||||
env_name = policy.info['env_name']
|
||||
|
||||
# Prepare params.
|
||||
params = config.DEFAULT_PARAMS
|
||||
if env_name in config.DEFAULT_ENV_PARAMS:
|
||||
params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in
|
||||
params['env_name'] = env_name
|
||||
params = config.prepare_params(params)
|
||||
config.log_params(params, logger=logger)
|
||||
|
||||
dims = config.configure_dims(params)
|
||||
|
||||
eval_params = {
|
||||
'exploit': True,
|
||||
'use_target_net': params['test_with_polyak'],
|
||||
'compute_Q': True,
|
||||
'rollout_batch_size': 1,
|
||||
'render': bool(render),
|
||||
}
|
||||
|
||||
for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
|
||||
eval_params[name] = params[name]
|
||||
|
||||
evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
|
||||
evaluator.seed(seed)
|
||||
|
||||
# Run evaluation.
|
||||
evaluator.clear_history()
|
||||
for _ in range(n_test_rollouts):
|
||||
evaluator.generate_rollouts()
|
||||
|
||||
# record logs
|
||||
for key, val in evaluator.logs('test'):
|
||||
logger.record_tabular(key, np.mean(val))
|
||||
logger.dump_tabular()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
baselines/her/experiment/plot.py (new file, 118 lines)
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import json
|
||||
import seaborn as sns; sns.set()
|
||||
import glob2
|
||||
import argparse
|
||||
|
||||
|
||||
def smooth_reward_curve(x, y):
|
||||
halfwidth = int(np.ceil(len(x) / 60)) # Halfwidth of our smoothing convolution
|
||||
k = halfwidth
|
||||
xsmoo = x
|
||||
ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='same') / np.convolve(np.ones_like(y), np.ones(2 * k + 1),
|
||||
mode='same')
|
||||
return xsmoo, ysmoo
|
||||
|
||||
|
||||
def load_results(file):
|
||||
if not os.path.exists(file):
|
||||
return None
|
||||
with open(file, 'r') as f:
|
||||
lines = [line for line in f]
|
||||
if len(lines) < 2:
|
||||
return None
|
||||
keys = [name.strip() for name in lines[0].split(',')]
|
||||
data = np.genfromtxt(file, delimiter=',', skip_header=1, filling_values=0.)
|
||||
if data.ndim == 1:
|
||||
data = data.reshape(1, -1)
|
||||
assert data.ndim == 2
|
||||
assert data.shape[-1] == len(keys)
|
||||
result = {}
|
||||
for idx, key in enumerate(keys):
|
||||
result[key] = data[:, idx]
|
||||
return result
|
||||
|
||||
|
||||
def pad(xs, value=np.nan):
|
||||
maxlen = np.max([len(x) for x in xs])
|
||||
|
||||
padded_xs = []
|
||||
for x in xs:
|
||||
if x.shape[0] >= maxlen:
|
||||
padded_xs.append(x)
|
||||
|
||||
padding = np.ones((maxlen - x.shape[0],) + x.shape[1:]) * value
|
||||
x_padded = np.concatenate([x, padding], axis=0)
|
||||
assert x_padded.shape[1:] == x.shape[1:]
|
||||
assert x_padded.shape[0] == maxlen
|
||||
padded_xs.append(x_padded)
|
||||
return np.array(padded_xs)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('dir', type=str)
|
||||
parser.add_argument('--smooth', type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load all data.
|
||||
data = {}
|
||||
paths = [os.path.abspath(os.path.join(path, '..')) for path in glob2.glob(os.path.join(args.dir, '**', 'progress.csv'))]
|
||||
for curr_path in paths:
|
||||
if not os.path.isdir(curr_path):
|
||||
continue
|
||||
results = load_results(os.path.join(curr_path, 'progress.csv'))
|
||||
if not results:
|
||||
print('skipping {}'.format(curr_path))
|
||||
continue
|
||||
print('loading {} ({})'.format(curr_path, len(results['epoch'])))
|
||||
with open(os.path.join(curr_path, 'metadata.json'), 'r') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
success_rate = np.array(results['test/success_rate'])
|
||||
epoch = np.array(results['epoch']) + 1
|
||||
env_id = metadata['kwargs']['env_name']
|
||||
replay_strategy = metadata['kwargs']['replay_strategy']
|
||||
|
||||
if replay_strategy == 'future':
|
||||
config = 'her'
|
||||
else:
|
||||
config = 'ddpg'
|
||||
if 'Dense' in env_id:
|
||||
config += '-dense'
|
||||
else:
|
||||
config += '-sparse'
|
||||
env_id = env_id.replace('Dense', '')
|
||||
|
||||
# Process and smooth data.
|
||||
assert success_rate.shape == epoch.shape
|
||||
x = epoch
|
||||
y = success_rate
|
||||
if args.smooth:
|
||||
x, y = smooth_reward_curve(epoch, success_rate)
|
||||
assert x.shape == y.shape
|
||||
|
||||
if env_id not in data:
|
||||
data[env_id] = {}
|
||||
if config not in data[env_id]:
|
||||
data[env_id][config] = []
|
||||
data[env_id][config].append((x, y))
|
||||
|
||||
# Plot data.
|
||||
for env_id in sorted(data.keys()):
|
||||
print('exporting {}'.format(env_id))
|
||||
plt.clf()
|
||||
|
||||
for config in sorted(data[env_id].keys()):
|
||||
xs, ys = zip(*data[env_id][config])
|
||||
xs, ys = pad(xs), pad(ys)
|
||||
assert xs.shape == ys.shape
|
||||
|
||||
plt.plot(xs[0], np.nanmedian(ys, axis=0), label=config)
|
||||
plt.fill_between(xs[0], np.nanpercentile(ys, 25, axis=0), np.nanpercentile(ys, 75, axis=0), alpha=0.25)
|
||||
plt.title(env_id)
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Median Success Rate')
|
||||
plt.legend()
|
||||
plt.savefig(os.path.join(args.dir, 'fig_{}.png'.format(env_id)))
|
baselines/her/experiment/train.py (new file, 172 lines)
@@ -0,0 +1,172 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
import json
|
||||
from mpi4py import MPI
|
||||
|
||||
from baselines import logger
|
||||
from baselines.common import set_global_seeds
|
||||
from baselines.common.mpi_moments import mpi_moments
|
||||
import baselines.her.experiment.config as config
|
||||
from baselines.her.rollout import RolloutWorker
|
||||
from baselines.her.util import mpi_fork
|
||||
|
||||
|
||||
def mpi_average(value):
|
||||
if value == []:
|
||||
value = [0.]
|
||||
if not isinstance(value, list):
|
||||
value = [value]
|
||||
return mpi_moments(np.array(value))[0]
|
||||
|
||||
|
||||
def train(policy, rollout_worker, evaluator,
|
||||
n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
|
||||
save_policies, **kwargs):
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
|
||||
latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl')
|
||||
best_policy_path = os.path.join(logger.get_dir(), 'policy_best.pkl')
|
||||
periodic_policy_path = os.path.join(logger.get_dir(), 'policy_{}.pkl')
|
||||
|
||||
logger.info("Training...")
|
||||
best_success_rate = -1
|
||||
for epoch in range(n_epochs):
|
||||
# train
|
||||
rollout_worker.clear_history()
|
||||
for _ in range(n_cycles):
|
||||
episode = rollout_worker.generate_rollouts()
|
||||
policy.store_episode(episode)
|
||||
for _ in range(n_batches):
|
||||
policy.train()
|
||||
policy.update_target_net()
|
||||
|
||||
# test
|
||||
evaluator.clear_history()
|
||||
for _ in range(n_test_rollouts):
|
||||
evaluator.generate_rollouts()
|
||||
|
||||
# record logs
|
||||
logger.record_tabular('epoch', epoch)
|
||||
for key, val in evaluator.logs('test'):
|
||||
logger.record_tabular(key, mpi_average(val))
|
||||
for key, val in rollout_worker.logs('train'):
|
||||
logger.record_tabular(key, mpi_average(val))
|
||||
for key, val in policy.logs():
|
||||
logger.record_tabular(key, mpi_average(val))
|
||||
|
||||
if rank == 0:
|
||||
logger.dump_tabular()
|
||||
|
||||
# save the policy if it's better than the previous ones
|
||||
success_rate = mpi_average(evaluator.current_success_rate())
|
||||
if rank == 0 and success_rate >= best_success_rate and save_policies:
|
||||
best_success_rate = success_rate
|
||||
logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
|
||||
evaluator.save_policy(best_policy_path)
|
||||
evaluator.save_policy(latest_policy_path)
|
||||
if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_policies:
|
||||
policy_path = periodic_policy_path.format(epoch)
|
||||
logger.info('Saving periodic policy to {} ...'.format(policy_path))
|
||||
evaluator.save_policy(policy_path)
|
||||
|
||||
# make sure that different threads have different seeds
|
||||
local_uniform = np.random.uniform(size=(1,))
|
||||
root_uniform = local_uniform.copy()
|
||||
MPI.COMM_WORLD.Bcast(root_uniform, root=0)
|
||||
if rank != 0:
|
||||
assert local_uniform[0] != root_uniform[0]
|
||||
|
||||
|
||||
def launch(
|
||||
env_name, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return,
|
||||
override_params={}, save_policies=True
|
||||
):
|
||||
# Fork for multi-CPU MPI implementation.
|
||||
if num_cpu > 1:
|
||||
whoami = mpi_fork(num_cpu)
|
||||
if whoami == 'parent':
|
||||
sys.exit(0)
|
||||
import baselines.common.tf_util as U
|
||||
U.single_threaded_session().__enter__()
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
|
||||
# Configure logging
|
||||
if rank == 0:
|
||||
if logdir or logger.get_dir() is None:
|
||||
logger.configure(dir=logdir)
|
||||
else:
|
||||
logger.configure()
|
||||
logdir = logger.get_dir()
|
||||
assert logdir is not None
|
||||
os.makedirs(logdir, exist_ok=True)
|
||||
|
||||
# Seed everything.
|
||||
rank_seed = seed + 1000000 * rank
|
||||
set_global_seeds(rank_seed)
|
||||
|
||||
# Prepare params.
|
||||
params = config.DEFAULT_PARAMS
|
||||
params['env_name'] = env_name
|
||||
params['replay_strategy'] = replay_strategy
|
||||
if env_name in config.DEFAULT_ENV_PARAMS:
|
||||
params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in
|
||||
params.update(**override_params) # makes it possible to override any parameter
|
||||
with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
|
||||
json.dump(params, f)
|
||||
params = config.prepare_params(params)
|
||||
config.log_params(params, logger=logger)
|
||||
|
||||
dims = config.configure_dims(params)
|
||||
policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
|
||||
|
||||
rollout_params = {
|
||||
'exploit': False,
|
||||
'use_target_net': False,
|
||||
'use_demo_states': True,
|
||||
'compute_Q': False,
|
||||
'T': params['T'],
|
||||
}
|
||||
|
||||
eval_params = {
|
||||
'exploit': True,
|
||||
'use_target_net': params['test_with_polyak'],
|
||||
'use_demo_states': False,
|
||||
'compute_Q': True,
|
||||
'T': params['T'],
|
||||
}
|
||||
|
||||
for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
|
||||
rollout_params[name] = params[name]
|
||||
eval_params[name] = params[name]
|
||||
|
||||
rollout_worker = RolloutWorker(params['make_env'], policy, dims, logger, **rollout_params)
|
||||
rollout_worker.seed(rank_seed)
|
||||
|
||||
evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
|
||||
evaluator.seed(rank_seed)
|
||||
|
||||
train(
|
||||
logdir=logdir, policy=policy, rollout_worker=rollout_worker,
|
||||
evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
|
||||
n_cycles=params['n_cycles'], n_batches=params['n_batches'],
|
||||
policy_save_interval=policy_save_interval, save_policies=save_policies)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--env_name', type=str, default='FetchReach-v0', help='the name of the OpenAI Gym environment that you want to train on')
|
||||
@click.option('--logdir', type=str, default=None, help='the path to where logs and policy pickles should go. If not specified, creates a folder in /tmp/')
|
||||
@click.option('--n_epochs', type=int, default=50, help='the number of training epochs to run')
|
||||
@click.option('--num_cpu', type=int, default=1, help='the number of CPU cores to use (using MPI)')
|
||||
@click.option('--seed', type=int, default=0, help='the random seed used to seed both the environment and the training code')
|
||||
@click.option('--policy_save_interval', type=int, default=5, help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.')
|
||||
@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.')
|
||||
@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped')
|
||||
def main(**kwargs):
|
||||
launch(**kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
baselines/her/her.py (new file, 63 lines)
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_sample_her_transitions(replay_strategy, replay_k, reward_fun):
|
||||
"""Creates a sample function that can be used for HER experience replay.
|
||||
|
||||
Args:
|
||||
replay_strategy (in ['future', 'none']): the HER replay strategy; if set to 'none',
|
||||
regular DDPG experience replay is used
|
||||
replay_k (int): the ratio between HER replays and regular replays (e.g. k = 4 -> 4 times
|
||||
as many HER replays as regular replays are used)
|
||||
reward_fun (function): function to re-compute the reward with substituted goals
|
||||
"""
|
||||
if replay_strategy == 'future':
|
||||
future_p = 1 - (1. / (1 + replay_k))
|
||||
else: # 'replay_strategy' == 'none'
|
||||
future_p = 0
|
||||
|
||||
def _sample_her_transitions(episode_batch, batch_size_in_transitions):
|
||||
"""episode_batch is {key: array(buffer_size x T x dim_key)}
|
||||
"""
|
||||
T = episode_batch['u'].shape[1]
|
||||
rollout_batch_size = episode_batch['u'].shape[0]
|
||||
batch_size = batch_size_in_transitions
|
||||
|
||||
# Select which episodes and time steps to use.
|
||||
episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
|
||||
t_samples = np.random.randint(T, size=batch_size)
|
||||
transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
|
||||
for key in episode_batch.keys()}
|
||||
|
||||
# Select future time indexes proportional with probability future_p. These
|
||||
# will be used for HER replay by substituting in future goals.
|
||||
her_indexes = np.where(np.random.uniform(size=batch_size) < future_p)
|
||||
future_offset = np.random.uniform(size=batch_size) * (T - t_samples)
|
||||
future_offset = future_offset.astype(int)
|
||||
future_t = (t_samples + 1 + future_offset)[her_indexes]
|
||||
|
||||
# Replace goal with achieved goal but only for the previously-selected
|
||||
# HER transitions (as defined by her_indexes). For the other transitions,
|
||||
# keep the original goal.
|
||||
future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
|
||||
transitions['g'][her_indexes] = future_ag
|
||||
|
||||
# Reconstruct info dictionary for reward computation.
|
||||
info = {}
|
||||
for key, value in transitions.items():
|
||||
if key.startswith('info_'):
|
||||
info[key.replace('info_', '')] = value
|
||||
|
||||
# Re-compute reward since we may have substituted the goal.
|
||||
reward_params = {k: transitions[k] for k in ['ag_2', 'g']}
|
||||
reward_params['info'] = info
|
||||
transitions['r'] = reward_fun(**reward_params)
|
||||
|
||||
transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:])
|
||||
for k in transitions.keys()}
|
||||
|
||||
assert(transitions['u'].shape[0] == batch_size_in_transitions)
|
||||
|
||||
return transitions
|
||||
|
||||
return _sample_her_transitions
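A small numeric check of the replay ratio implied by `replay_k` (value illustrative):

```python
replay_k = 4
future_p = 1 - (1. / (1 + replay_k))
# future_p == 0.8: on average 4 future-goal (HER) replays are sampled for every
# transition that keeps its original goal
```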
|
baselines/her/normalizer.py (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
from mpi4py import MPI
|
||||
import tensorflow as tf
|
||||
|
||||
from baselines.her.util import reshape_for_broadcasting
|
||||
|
||||
|
||||
class Normalizer:
|
||||
def __init__(self, size, eps=1e-2, default_clip_range=np.inf, sess=None):
|
||||
"""A normalizer that ensures that observations are approximately distributed according to
|
||||
a standard Normal distribution (i.e. have mean zero and variance one).
|
||||
|
||||
Args:
|
||||
size (int): the size of the observation to be normalized
|
||||
eps (float): a small constant that avoids underflows
|
||||
default_clip_range (float): normalized observations are clipped to be in
|
||||
[-default_clip_range, default_clip_range]
|
||||
sess (object): the TensorFlow session to be used
|
||||
"""
|
||||
self.size = size
|
||||
self.eps = eps
|
||||
self.default_clip_range = default_clip_range
|
||||
self.sess = sess if sess is not None else tf.get_default_session()
|
||||
|
||||
self.local_sum = np.zeros(self.size, np.float32)
|
||||
self.local_sumsq = np.zeros(self.size, np.float32)
|
||||
self.local_count = np.zeros(1, np.float32)
|
||||
|
||||
self.sum_tf = tf.get_variable(
|
||||
initializer=tf.zeros_initializer(), shape=self.local_sum.shape, name='sum',
|
||||
trainable=False, dtype=tf.float32)
|
||||
self.sumsq_tf = tf.get_variable(
|
||||
initializer=tf.zeros_initializer(), shape=self.local_sumsq.shape, name='sumsq',
|
||||
trainable=False, dtype=tf.float32)
|
||||
self.count_tf = tf.get_variable(
|
||||
initializer=tf.ones_initializer(), shape=self.local_count.shape, name='count',
|
||||
trainable=False, dtype=tf.float32)
|
||||
self.mean = tf.get_variable(
|
||||
initializer=tf.zeros_initializer(), shape=(self.size,), name='mean',
|
||||
trainable=False, dtype=tf.float32)
|
||||
self.std = tf.get_variable(
|
||||
initializer=tf.ones_initializer(), shape=(self.size,), name='std',
|
||||
trainable=False, dtype=tf.float32)
|
||||
self.count_pl = tf.placeholder(name='count_pl', shape=(1,), dtype=tf.float32)
|
||||
self.sum_pl = tf.placeholder(name='sum_pl', shape=(self.size,), dtype=tf.float32)
|
||||
self.sumsq_pl = tf.placeholder(name='sumsq_pl', shape=(self.size,), dtype=tf.float32)
|
||||
|
||||
self.update_op = tf.group(
|
||||
self.count_tf.assign_add(self.count_pl),
|
||||
self.sum_tf.assign_add(self.sum_pl),
|
||||
self.sumsq_tf.assign_add(self.sumsq_pl)
|
||||
)
|
||||
self.recompute_op = tf.group(
|
||||
tf.assign(self.mean, self.sum_tf / self.count_tf),
|
||||
tf.assign(self.std, tf.sqrt(tf.maximum(
|
||||
tf.square(self.eps),
|
||||
self.sumsq_tf / self.count_tf - tf.square(self.sum_tf / self.count_tf)
|
||||
))),
|
||||
)
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def update(self, v):
|
||||
v = v.reshape(-1, self.size)
|
||||
|
||||
with self.lock:
|
||||
self.local_sum += v.sum(axis=0)
|
||||
self.local_sumsq += (np.square(v)).sum(axis=0)
|
||||
self.local_count[0] += v.shape[0]
|
||||
|
||||
def normalize(self, v, clip_range=None):
|
||||
if clip_range is None:
|
||||
clip_range = self.default_clip_range
|
||||
mean = reshape_for_broadcasting(self.mean, v)
|
||||
std = reshape_for_broadcasting(self.std, v)
|
||||
return tf.clip_by_value((v - mean) / std, -clip_range, clip_range)
|
||||
|
||||
def denormalize(self, v):
|
||||
mean = reshape_for_broadcasting(self.mean, v)
|
||||
std = reshape_for_broadcasting(self.std, v)
|
||||
return mean + v * std
|
||||
|
||||
def _mpi_average(self, x):
|
||||
buf = np.zeros_like(x)
|
||||
MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM)
|
||||
buf /= MPI.COMM_WORLD.Get_size()
|
||||
return buf
|
||||
|
||||
def synchronize(self, local_sum, local_sumsq, local_count, root=None):
|
||||
local_sum[...] = self._mpi_average(local_sum)
|
||||
local_sumsq[...] = self._mpi_average(local_sumsq)
|
||||
local_count[...] = self._mpi_average(local_count)
|
||||
return local_sum, local_sumsq, local_count
|
||||
|
||||
def recompute_stats(self):
|
||||
with self.lock:
|
||||
# Copy over results.
|
||||
local_count = self.local_count.copy()
|
||||
local_sum = self.local_sum.copy()
|
||||
local_sumsq = self.local_sumsq.copy()
|
||||
|
||||
# Reset.
|
||||
self.local_count[...] = 0
|
||||
self.local_sum[...] = 0
|
||||
self.local_sumsq[...] = 0
|
||||
|
||||
# We perform the synchronization outside of the lock to keep the critical section as short
|
||||
# as possible.
|
||||
synced_sum, synced_sumsq, synced_count = self.synchronize(
|
||||
local_sum=local_sum, local_sumsq=local_sumsq, local_count=local_count)
|
||||
|
||||
self.sess.run(self.update_op, feed_dict={
|
||||
self.count_pl: synced_count,
|
||||
self.sum_pl: synced_sum,
|
||||
self.sumsq_pl: synced_sumsq,
|
||||
})
|
||||
self.sess.run(self.recompute_op)
|
||||
|
||||
|
||||
class IdentityNormalizer:
|
||||
def __init__(self, size, std=1.):
|
||||
self.size = size
|
||||
self.mean = tf.zeros(self.size, tf.float32)
|
||||
self.std = std * tf.ones(self.size, tf.float32)
|
||||
|
||||
def update(self, x):
|
||||
pass
|
||||
|
||||
def normalize(self, x, clip_range=None):
|
||||
return x / self.std
|
||||
|
||||
def denormalize(self, x):
|
||||
return self.std * x
|
||||
|
||||
def synchronize(self):
|
||||
pass
|
||||
|
||||
def recompute_stats(self):
|
||||
pass
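A hedged usage sketch of `Normalizer` (requires mpi4py and an active TensorFlow session; sizes, scope name, and data are illustrative):

```python
import numpy as np
import tensorflow as tf
from baselines.her.normalizer import Normalizer

sess = tf.InteractiveSession()
with tf.variable_scope('o_stats'):
    norm = Normalizer(size=10, eps=0.01, default_clip_range=5)
sess.run(tf.global_variables_initializer())

norm.update(np.random.randn(256, 10).astype(np.float32))   # accumulate local sums
norm.recompute_stats()                                      # MPI-average and refresh mean/std
obs_normalized = norm.normalize(tf.constant(np.random.randn(4, 10), dtype=tf.float32))
```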
|
baselines/her/replay_buffer.py (new file, 108 lines)
@@ -0,0 +1,108 @@
import threading

import numpy as np


class ReplayBuffer:
    def __init__(self, buffer_shapes, size_in_transitions, T, sample_transitions):
        """Creates a replay buffer.

        Args:
            buffer_shapes (dict of ints): the shape for all buffers that are used in the replay
                buffer
            size_in_transitions (int): the size of the buffer, measured in transitions
            T (int): the time horizon for episodes
            sample_transitions (function): a function that samples from the replay buffer
        """
        self.buffer_shapes = buffer_shapes
        self.size = size_in_transitions // T
        self.T = T
        self.sample_transitions = sample_transitions

        # self.buffers is {key: array(size_in_episodes x T or T+1 x dim_key)}
        self.buffers = {key: np.empty([self.size, *shape])
                        for key, shape in buffer_shapes.items()}

        # memory management
        self.current_size = 0
        self.n_transitions_stored = 0

        self.lock = threading.Lock()

    @property
    def full(self):
        with self.lock:
            return self.current_size == self.size

    def sample(self, batch_size):
        """Returns a dict {key: array(batch_size x shapes[key])}
        """
        buffers = {}

        with self.lock:
            assert self.current_size > 0
            for key in self.buffers.keys():
                buffers[key] = self.buffers[key][:self.current_size]

        buffers['o_2'] = buffers['o'][:, 1:, :]
        buffers['ag_2'] = buffers['ag'][:, 1:, :]

        transitions = self.sample_transitions(buffers, batch_size)

        for key in (['r', 'o_2', 'ag_2'] + list(self.buffers.keys())):
            assert key in transitions, "key %s missing from transitions" % key

        return transitions

    def store_episode(self, episode_batch):
        """episode_batch: array(batch_size x (T or T+1) x dim_key)
        """
        batch_sizes = [len(episode_batch[key]) for key in episode_batch.keys()]
        assert np.all(np.array(batch_sizes) == batch_sizes[0])
        batch_size = batch_sizes[0]

        with self.lock:
            idxs = self._get_storage_idx(batch_size)

            # load inputs into buffers
            for key in self.buffers.keys():
                self.buffers[key][idxs] = episode_batch[key]

            self.n_transitions_stored += batch_size * self.T

    def get_current_episode_size(self):
        with self.lock:
            return self.current_size

    def get_current_size(self):
        with self.lock:
            return self.current_size * self.T

    def get_transitions_stored(self):
        with self.lock:
            return self.n_transitions_stored

    def clear_buffer(self):
        with self.lock:
            self.current_size = 0

    def _get_storage_idx(self, inc=None):
        inc = inc or 1  # size increment
        assert inc <= self.size, "Batch committed to replay is too large!"
        # go consecutively until you hit the end, and then go randomly.
        if self.current_size+inc <= self.size:
            idx = np.arange(self.current_size, self.current_size+inc)
        elif self.current_size < self.size:
            overflow = inc - (self.size - self.current_size)
            idx_a = np.arange(self.current_size, self.size)
            idx_b = np.random.randint(0, self.current_size, overflow)
            idx = np.concatenate([idx_a, idx_b])
        else:
            idx = np.random.randint(0, self.size, inc)

        # update replay size
        self.current_size = min(self.size, self.current_size+inc)

        if inc == 1:
            idx = idx[0]
        return idx
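To make the ReplayBuffer interface above concrete, here is a hedged toy usage sketch. The buffer shapes mirror the keys that `sample()` expects ('o' and 'ag' carry T+1 steps), and `uniform_sample_transitions` is a made-up uniform sampler with a dummy distance reward, standing in for the real HER relabelling sampler.

import numpy as np
from baselines.her.replay_buffer import ReplayBuffer

T, dim_o, dim_g, dim_u = 5, 4, 3, 2

# Shapes per episode; 'o' and 'ag' include the final state, hence T+1 entries.
buffer_shapes = {'o': (T + 1, dim_o), 'ag': (T + 1, dim_g),
                 'g': (T, dim_g), 'u': (T, dim_u)}

def uniform_sample_transitions(episode_batch, batch_size):
    # Toy sampler: uniform over stored transitions, with a placeholder reward.
    # The real HER sampler additionally relabels goals; this is only for illustration.
    n_episodes = episode_batch['u'].shape[0]
    ep_idx = np.random.randint(0, n_episodes, batch_size)
    t_idx = np.random.randint(0, T, batch_size)
    transitions = {key: episode_batch[key][ep_idx, t_idx] for key in episode_batch.keys()}
    transitions['r'] = -np.linalg.norm(transitions['ag_2'] - transitions['g'], axis=-1)
    return transitions

buffer = ReplayBuffer(buffer_shapes, size_in_transitions=100 * T, T=T,
                      sample_transitions=uniform_sample_transitions)

# Store a fake batch of 2 episodes, then sample 8 transitions.
episode = {'o': np.zeros((2, T + 1, dim_o)), 'ag': np.zeros((2, T + 1, dim_g)),
           'g': np.zeros((2, T, dim_g)), 'u': np.zeros((2, T, dim_u))}
buffer.store_episode(episode)
batch = buffer.sample(8)
print({k: v.shape for k, v in batch.items()})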
188  baselines/her/rollout.py  Normal file
@@ -0,0 +1,188 @@
from collections import deque

import numpy as np
import pickle
from mujoco_py import MujocoException

from baselines.her.util import convert_episode_to_batch_major, store_args


class RolloutWorker:

    @store_args
    def __init__(self, make_env, policy, dims, logger, T, rollout_batch_size=1,
                 exploit=False, use_target_net=False, compute_Q=False, noise_eps=0,
                 random_eps=0, history_len=100, render=False, **kwargs):
        """Rollout worker generates experience by interacting with one or many environments.

        Args:
            make_env (function): a factory function that creates a new instance of the environment
                when called
            policy (object): the policy that is used to act
            dims (dict of ints): the dimensions for observations (o), goals (g), and actions (u)
            logger (object): the logger that is used by the rollout worker
            rollout_batch_size (int): the number of parallel rollouts that should be used
            exploit (boolean): whether or not to exploit, i.e. to act optimally according to the
                current policy without any exploration
            use_target_net (boolean): whether or not to use the target net for rollouts
            compute_Q (boolean): whether or not to compute the Q values alongside the actions
            noise_eps (float): scale of the additive Gaussian noise
            random_eps (float): probability of selecting a completely random action
            history_len (int): length of history for statistics smoothing
            render (boolean): whether or not to render the rollouts
        """
        self.envs = [make_env() for _ in range(rollout_batch_size)]
        assert self.T > 0

        self.info_keys = [key.replace('info_', '') for key in dims.keys() if key.startswith('info_')]

        self.success_history = deque(maxlen=history_len)
        self.Q_history = deque(maxlen=history_len)

        self.n_episodes = 0
        self.g = np.empty((self.rollout_batch_size, self.dims['g']), np.float32)  # goals
        self.initial_o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32)  # observations
        self.initial_ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32)  # achieved goals
        self.reset_all_rollouts()
        self.clear_history()

    def reset_rollout(self, i):
        """Resets the `i`-th rollout environment, re-samples a new goal, and updates the `initial_o`
        and `g` arrays accordingly.
        """
        obs = self.envs[i].reset()
        self.initial_o[i] = obs['observation']
        self.initial_ag[i] = obs['achieved_goal']
        self.g[i] = obs['desired_goal']

    def reset_all_rollouts(self):
        """Resets all `rollout_batch_size` rollout workers.
        """
        for i in range(self.rollout_batch_size):
            self.reset_rollout(i)

    def generate_rollouts(self):
        """Performs `rollout_batch_size` rollouts in parallel for time horizon `T` with the current
        policy acting on it accordingly.
        """
        self.reset_all_rollouts()

        # compute observations
        o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32)  # observations
        ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32)  # achieved goals
        o[:] = self.initial_o
        ag[:] = self.initial_ag

        # generate episodes
        obs, achieved_goals, acts, goals, successes = [], [], [], [], []
        info_values = [np.empty((self.T, self.rollout_batch_size, self.dims['info_' + key]), np.float32) for key in self.info_keys]
        Qs = []
        for t in range(self.T):
            policy_output = self.policy.get_actions(
                o, ag, self.g,
                compute_Q=self.compute_Q,
                noise_eps=self.noise_eps if not self.exploit else 0.,
                random_eps=self.random_eps if not self.exploit else 0.,
                use_target_net=self.use_target_net)

            if self.compute_Q:
                u, Q = policy_output
                Qs.append(Q)
            else:
                u = policy_output

            if u.ndim == 1:
                # The non-batched case should still have a reasonable shape.
                u = u.reshape(1, -1)

            o_new = np.empty((self.rollout_batch_size, self.dims['o']))
            ag_new = np.empty((self.rollout_batch_size, self.dims['g']))
            success = np.zeros(self.rollout_batch_size)
            # compute new states and observations
            for i in range(self.rollout_batch_size):
                try:
                    # We fully ignore the reward here because it will have to be re-computed
                    # for HER.
                    curr_o_new, _, _, info = self.envs[i].step(u[i])
                    if 'is_success' in info:
                        success[i] = info['is_success']
                    o_new[i] = curr_o_new['observation']
                    ag_new[i] = curr_o_new['achieved_goal']
                    for idx, key in enumerate(self.info_keys):
                        info_values[idx][t, i] = info[key]
                    if self.render:
                        self.envs[i].render()
                except MujocoException as e:
                    return self.generate_rollouts()

            if np.isnan(o_new).any():
                self.logger.warning('NaN caught during rollout generation. Trying again...')
                self.reset_all_rollouts()
                return self.generate_rollouts()

            obs.append(o.copy())
            achieved_goals.append(ag.copy())
            successes.append(success.copy())
            acts.append(u.copy())
            goals.append(self.g.copy())
            o[...] = o_new
            ag[...] = ag_new
        obs.append(o.copy())
        achieved_goals.append(ag.copy())
        self.initial_o[:] = o

        episode = dict(o=obs,
                       u=acts,
                       g=goals,
                       ag=achieved_goals)
        for key, value in zip(self.info_keys, info_values):
            episode['info_{}'.format(key)] = value

        # stats
        successful = np.array(successes)[-1, :]
        assert successful.shape == (self.rollout_batch_size,)
        success_rate = np.mean(successful)
        self.success_history.append(success_rate)
        if self.compute_Q:
            self.Q_history.append(np.mean(Qs))
        self.n_episodes += self.rollout_batch_size

        return convert_episode_to_batch_major(episode)

    def clear_history(self):
        """Clears all histories that are used for statistics
        """
        self.success_history.clear()
        self.Q_history.clear()

    def current_success_rate(self):
        return np.mean(self.success_history)

    def current_mean_Q(self):
        return np.mean(self.Q_history)

    def save_policy(self, path):
        """Pickles the current policy for later inspection.
        """
        with open(path, 'wb') as f:
            pickle.dump(self.policy, f)

    def logs(self, prefix='worker'):
        """Generates a dictionary that contains all collected statistics.
        """
        logs = []
        logs += [('success_rate', np.mean(self.success_history))]
        if self.compute_Q:
            logs += [('mean_Q', np.mean(self.Q_history))]
        logs += [('episode', self.n_episodes)]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def seed(self, seed):
        """Seeds each environment with a distinct seed derived from the passed in global seed.
        """
        for idx, env in enumerate(self.envs):
            env.seed(seed + 1000 * idx)
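A quick standalone shape check (with made-up dimensions) of the episode layout that `generate_rollouts` returns: observations and achieved goals are recorded T+1 times, actions and goals T times, and `convert_episode_to_batch_major` swaps the time axis with the rollout (batch) axis.

import numpy as np
from baselines.her.util import convert_episode_to_batch_major

T, rollout_batch_size, dim_o, dim_g, dim_u = 5, 2, 4, 3, 2

# Time-major lists, as generate_rollouts collects them step by step:
episode = dict(
    o=[np.zeros((rollout_batch_size, dim_o)) for _ in range(T + 1)],   # T+1 observations
    ag=[np.zeros((rollout_batch_size, dim_g)) for _ in range(T + 1)],  # T+1 achieved goals
    g=[np.zeros((rollout_batch_size, dim_g)) for _ in range(T)],       # T goals
    u=[np.zeros((rollout_batch_size, dim_u)) for _ in range(T)],       # T actions
)
batch_major = convert_episode_to_batch_major(episode)
print({k: v.shape for k, v in batch_major.items()})
# e.g. 'o' -> (rollout_batch_size, T + 1, dim_o), 'u' -> (rollout_batch_size, T, dim_u)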
144  baselines/her/util.py  Normal file
@@ -0,0 +1,144 @@
import os
import subprocess
import sys
import importlib
import inspect
import functools

import tensorflow as tf
import numpy as np

from baselines.common import tf_util as U


def store_args(method):
    """Stores provided method args as instance attributes.
    """
    argspec = inspect.getfullargspec(method)
    defaults = {}
    if argspec.defaults is not None:
        defaults = dict(
            zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
    if argspec.kwonlydefaults is not None:
        defaults.update(argspec.kwonlydefaults)
    arg_names = argspec.args[1:]

    @functools.wraps(method)
    def wrapper(*positional_args, **keyword_args):
        self = positional_args[0]
        # Get default arg values
        args = defaults.copy()
        # Add provided arg values
        for name, value in zip(arg_names, positional_args[1:]):
            args[name] = value
        args.update(keyword_args)
        self.__dict__.update(args)
        return method(*positional_args, **keyword_args)

    return wrapper


def import_function(spec):
    """Import a function identified by a string like "pkg.module:fn_name".
    """
    mod_name, fn_name = spec.split(':')
    module = importlib.import_module(mod_name)
    fn = getattr(module, fn_name)
    return fn


def flatten_grads(var_list, grads):
    """Flattens variables and their gradients.
    """
    return tf.concat([tf.reshape(grad, [U.numel(v)])
                      for (v, grad) in zip(var_list, grads)], 0)


def nn(input, layers_sizes, reuse=None, flatten=False, name=""):
    """Creates a simple neural network
    """
    for i, size in enumerate(layers_sizes):
        activation = tf.nn.relu if i < len(layers_sizes)-1 else None
        input = tf.layers.dense(inputs=input,
                                units=size,
                                kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                reuse=reuse,
                                name=name+'_'+str(i))
        if activation:
            input = activation(input)
    if flatten:
        assert layers_sizes[-1] == 1
        input = tf.reshape(input, [-1])
    return input


def install_mpi_excepthook():
    import sys
    from mpi4py import MPI
    old_hook = sys.excepthook

    def new_hook(a, b, c):
        old_hook(a, b, c)
        sys.stdout.flush()
        sys.stderr.flush()
        MPI.COMM_WORLD.Abort()
    sys.excepthook = new_hook


def mpi_fork(n):
    """Re-launches the current script with workers
    Returns "parent" for original parent, "child" for MPI children
    """
    if n <= 1:
        return "child"
    if os.getenv("IN_MPI") is None:
        env = os.environ.copy()
        env.update(
            MKL_NUM_THREADS="1",
            OMP_NUM_THREADS="1",
            IN_MPI="1"
        )
        # "-bind-to core" is crucial for good performance
        args = [
            "mpirun",
            "-np",
            str(n),
            "-bind-to",
            "core",
            sys.executable
        ]
        args += sys.argv
        subprocess.check_call(args, env=env)
        return "parent"
    else:
        install_mpi_excepthook()
        return "child"


def convert_episode_to_batch_major(episode):
    """Converts an episode to have the batch dimension in the major (first)
    dimension.
    """
    episode_batch = {}
    for key in episode.keys():
        val = np.array(episode[key]).copy()
        # make inputs batch-major instead of time-major
        episode_batch[key] = val.swapaxes(0, 1)

    return episode_batch


def transitions_in_episode_batch(episode_batch):
    """Number of transitions in a given episode batch.
    """
    shape = episode_batch['u'].shape
    return shape[0] * shape[1]


def reshape_for_broadcasting(source, target):
    """Reshapes a tensor (source) to have the correct shape and dtype of the target
    before broadcasting it against the target.
    """
    dim = len(target.get_shape())
    shape = ([1] * (dim-1)) + [-1]
    return tf.reshape(tf.cast(source, target.dtype), shape)
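Two of the helpers above in a short, self-contained usage sketch. The Agent class and its arguments are invented purely for illustration, and the import_function target is just a stdlib example of the "module:attribute" string format it accepts.

from baselines.her.util import store_args, import_function

class Agent:
    @store_args
    def __init__(self, dims, lr=1e-3, polyak=0.95, **kwargs):
        # No manual self.dims = dims, self.lr = lr, ... needed: store_args has
        # already copied every argument (defaults included) onto the instance.
        print(self.dims, self.lr, self.polyak)

agent = Agent({'o': 10, 'g': 3, 'u': 4}, polyak=0.98)

# import_function resolves "module.path:attribute" strings; os.path:join is
# only used here as a stand-in target that is guaranteed to exist.
join = import_function('os.path:join')
print(join('logs', 'run1'))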
@@ -6,6 +6,7 @@ import json
import time
import datetime
import tempfile
from collections import defaultdict

LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv']
# Also valid: json, tensorboard
@@ -124,7 +125,7 @@ class CSVOutputFormat(KVWriter):
            if i > 0:
                self.file.write(',')
            v = kvs.get(k)
            if v:
            if v is not None:
                self.file.write(str(v))
        self.file.write('\n')
        self.file.flush()
@@ -168,24 +169,18 @@ class TensorBoardOutputFormat(KVWriter):
        self.writer.Close()
        self.writer = None

def make_output_format(format, ev_dir):
    from mpi4py import MPI
def make_output_format(format, ev_dir, log_suffix=''):
    os.makedirs(ev_dir, exist_ok=True)
    rank = MPI.COMM_WORLD.Get_rank()
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    elif format == 'log':
        suffix = "" if rank==0 else ("-mpi%03i"%rank)
        return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % suffix))
        return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix))
    elif format == 'json':
        assert rank==0
        return JSONOutputFormat(osp.join(ev_dir, 'progress.json'))
        return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix))
    elif format == 'csv':
        assert rank==0
        return CSVOutputFormat(osp.join(ev_dir, 'progress.csv'))
        return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix))
    elif format == 'tensorboard':
        assert rank==0
        return TensorBoardOutputFormat(osp.join(ev_dir, 'tb'))
        return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix))
    else:
        raise ValueError('Unknown format specified: %s' % (format,))

@@ -197,9 +192,16 @@ def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    If called many times, last value will be used.
    """
    Logger.CURRENT.logkv(key, val)

def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times, values averaged.
    """
    Logger.CURRENT.logkv_mean(key, val)

def logkvs(d):
    """
    Log a dictionary of key-value pairs
@@ -255,6 +257,33 @@ def get_dir():
record_tabular = logkv
dump_tabular = dumpkvs

class ProfileKV:
    """
    Usage:
    with logger.ProfileKV("interesting_scope"):
        code
    """
    def __init__(self, n):
        self.n = "wait_" + n
    def __enter__(self):
        self.t1 = time.time()
    def __exit__(self, type, value, traceback):
        Logger.CURRENT.name2val[self.n] += time.time() - self.t1

def profile(n):
    """
    Usage:
    @profile("my_func")
    def my_func(): code
    """
    def decorator_with_name(func):
        def func_wrapper(*args, **kwargs):
            with ProfileKV(n):
                return func(*args, **kwargs)
        return func_wrapper
    return decorator_with_name


# ================================================================
# Backend
# ================================================================
@@ -265,7 +294,8 @@ class Logger(object):
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats):
        self.name2val = {}  # values this iteration
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
@@ -275,12 +305,21 @@ class Logger(object):
    def logkv(self, key, val):
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        if val is None:
            self.name2val[key] = None
            return
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        if self.level == DISABLED: return
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(self.name2val)
        self.name2val.clear()
        self.name2cnt.clear()

    def log(self, *args, level=INFO):
        if self.level <= level:
@@ -360,6 +399,11 @@ def _demo():
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see a = 5.5")
    logkv_mean("b", -22.5)
    logkv_mean("b", -44.4)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see b = -33.45")

    logkv("b", -2.5)
    dumpkvs()

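The incremental average in Logger.logkv_mean above is a plain running mean; a tiny standalone check of the update rule with the demo's values:

# Running-mean update used by Logger.logkv_mean:
#   new_mean = old_mean * cnt / (cnt + 1) + val / (cnt + 1)
mean, cnt = 0.0, 0
for val in (-22.5, -44.4):
    mean = mean * cnt / (cnt + 1) + val / (cnt + 1)
    cnt += 1
print(mean)  # -33.45, the plain average of the two logged values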
6  setup.py
@@ -10,7 +10,7 @@ setup(name='baselines',
      packages=[package for package in find_packages()
                if package.startswith('baselines')],
      install_requires=[
          'gym[mujoco,atari,classic_control]',
          'gym[mujoco,atari,classic_control,robotics]',
          'scipy',
          'tqdm',
          'joblib',
@@ -19,9 +19,11 @@ setup(name='baselines',
          'progressbar2',
          'mpi4py',
          'cloudpickle',
          'tensorflow>=1.4.0',
          'click',
      ],
      description='OpenAI baselines: high quality implementations of reinforcement learning algorithms',
      author='OpenAI',
      url='https://github.com/openai/baselines',
      author_email='gym@openai.com',
      version='0.1.4')
      version='0.1.5')