diff --git a/README.md b/README.md index f1d85e1..cbda551 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ pip install -e . - [DDPG](baselines/ddpg) - [DQN](baselines/deepq) - [GAIL](baselines/gail) +- [HER](baselines/her) - [PPO1](baselines/ppo1) (Multi-CPU using MPI) - [PPO2](baselines/ppo2) (Optimized for GPU) - [TRPO](baselines/trpo_mpi) diff --git a/baselines/a2c/utils.py b/baselines/a2c/utils.py index 3c362ec..0964af8 100644 --- a/baselines/a2c/utils.py +++ b/baselines/a2c/utils.py @@ -39,12 +39,24 @@ def ortho_init(scale=1.0): return (scale * q[:shape[0], :shape[1]]).astype(np.float32) return _ortho_init -def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0): +def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC'): + if data_format == 'NHWC': + channel_ax = 3 + strides = [1, stride, stride, 1] + bshape = [1, 1, 1, nf] + elif data_format == 'NCHW': + channel_ax = 1 + strides = [1, 1, stride, stride] + bshape = [1, nf, 1, 1] + else: + raise NotImplementedError + nin = x.get_shape()[channel_ax].value + wshape = [rf, rf, nin, nf] with tf.variable_scope(scope): - nin = x.get_shape()[3].value - w = tf.get_variable("w", [rf, rf, nin, nf], initializer=ortho_init(init_scale)) - b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0.0)) - return tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=pad)+b + w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale)) + b = tf.get_variable("b", [1, nf, 1, 1], initializer=tf.constant_initializer(0.0)) + if data_format == 'NHWC': b = tf.reshape(b, bshape) + return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format) def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0): with tf.variable_scope(scope): diff --git a/baselines/bench/monitor.py b/baselines/bench/monitor.py index 73c127c..0da1b4f 100644 --- a/baselines/bench/monitor.py +++ b/baselines/bench/monitor.py @@ -7,12 +7,13 @@ from glob import glob import csv import os.path as osp import json +import numpy as np class Monitor(Wrapper): EXT = "monitor.csv" f = None - def __init__(self, env, filename, allow_early_resets=False, reset_keywords=()): + def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()): Wrapper.__init__(self, env=env) self.tstart = time.time() if filename is None: @@ -26,10 +27,12 @@ class Monitor(Wrapper): filename = filename + "." + Monitor.EXT self.f = open(filename, "wt") self.f.write('#%s\n'%json.dumps({"t_start": self.tstart, 'env_id' : env.spec and env.spec.id})) - self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords) + self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords+info_keywords) self.logger.writeheader() + self.f.flush() self.reset_keywords = reset_keywords + self.info_keywords = info_keywords self.allow_early_resets = allow_early_resets self.rewards = None self.needs_reset = True @@ -61,6 +64,8 @@ class Monitor(Wrapper): eprew = sum(self.rewards) eplen = len(self.rewards) epinfo = {"r": round(eprew, 6), "l": eplen, "t": round(time.time() - self.tstart, 6)} + for k in self.info_keywords: + epinfo[k] = info[k] self.episode_rewards.append(eprew) self.episode_lengths.append(eplen) self.episode_times.append(time.time() - self.tstart) diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index d8a48ae..8a3304b 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -4,6 +4,7 @@ Helpers for scripts like run_atari.py. 
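As a quick illustration (an editorial sketch, not part of the patch), the new data_format argument added to a2c's conv above lets channels-first tensors be used directly; in the NCHW branch the bias is kept in shape [1, nf, 1, 1]:

import tensorflow as tf
from baselines.a2c.utils import conv

x = tf.placeholder(tf.float32, [None, 4, 84, 84])             # channels-first (NCHW) input
h = conv(x, 'c1', nf=32, rf=8, stride=4, data_format='NCHW')  # strides become [1, 1, s, s]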
import os import gym +from gym.wrappers import FlattenDictWrapper from baselines import logger from baselines.bench import Monitor from baselines.common import set_global_seeds @@ -36,6 +37,19 @@ def make_mujoco_env(env_id, seed): env.seed(seed) return env +def make_robotics_env(env_id, seed, rank=0): + """ + Create a wrapped, monitored gym.Env for MuJoCo. + """ + set_global_seeds(seed) + env = gym.make(env_id) + env = FlattenDictWrapper(env, ['observation', 'desired_goal']) + env = Monitor( + env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), + info_keywords=('is_success',)) + env.seed(seed) + return env + def arg_parser(): """ Create an empty argparse.ArgumentParser. @@ -58,7 +72,17 @@ def mujoco_arg_parser(): Create an argparse.ArgumentParser for run_mujoco.py. """ parser = arg_parser() - parser.add_argument('--env', help='environment ID', type=str, default="Reacher-v1") + parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2') + parser.add_argument('--seed', help='RNG seed', type=int, default=0) + parser.add_argument('--num-timesteps', type=int, default=int(1e6)) + return parser + +def robotics_arg_parser(): + """ + Create an argparse.ArgumentParser for run_mujoco.py. + """ + parser = arg_parser() + parser.add_argument('--env', help='environment ID', type=str, default='FetchReach-v0') parser.add_argument('--seed', help='RNG seed', type=int, default=0) parser.add_argument('--num-timesteps', type=int, default=int(1e6)) return parser diff --git a/baselines/common/console_util.py b/baselines/common/console_util.py index 8b4ae8e..8adc3f8 100644 --- a/baselines/common/console_util.py +++ b/baselines/common/console_util.py @@ -16,7 +16,12 @@ def fmt_item(x, l): if isinstance(x, np.ndarray): assert x.ndim==0 x = x.item() - if isinstance(x, float): rep = "%g"%x + if isinstance(x, (float, np.float32, np.float64)): + v = abs(x) + if (v < 1e-4 or v > 1e+4) and v > 0: + rep = "%7.2e" % x + else: + rep = "%7.5f" % x else: rep = str(x) return " "*(l - len(rep)) + rep diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py index 6cd90b1..3396bb8 100644 --- a/baselines/common/tf_util.py +++ b/baselines/common/tf_util.py @@ -261,3 +261,20 @@ def get_placeholder_cached(name): def flattenallbut0(x): return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])]) + + +# ================================================================ +# Diagnostics +# ================================================================ + +def display_var_info(vars): + from baselines import logger + count_params = 0 + for v in vars: + name = v.name + if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue + count_params += np.prod(v.shape.as_list()) + if "/b:" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print + logger.info(" %s%s%s" % (name, " "*(55-len(name)), str(v.shape))) + logger.info("Total model parameters: %0.1f million" % (count_params*1e-6)) + diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py index c70db68..a09e375 100644 --- a/baselines/common/vec_env/dummy_vec_env.py +++ b/baselines/common/vec_env/dummy_vec_env.py @@ -1,31 +1,42 @@ import numpy as np +import gym from . 
import VecEnv class DummyVecEnv(VecEnv): def __init__(self, env_fns): self.envs = [fn() for fn in env_fns] - env = self.envs[0] + env = self.envs[0] VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) - self.ts = np.zeros(len(self.envs), dtype='int') + + obs_spaces = self.observation_space.spaces if isinstance(self.observation_space, gym.spaces.Tuple) else (self.observation_space,) + self.buf_obs = [np.zeros((self.num_envs,) + tuple(s.shape), s.dtype) for s in obs_spaces] + self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) + self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) + self.buf_infos = [{} for _ in range(self.num_envs)] self.actions = None def step_async(self, actions): self.actions = actions def step_wait(self): - results = [env.step(a) for (a,env) in zip(self.actions, self.envs)] - obs, rews, dones, infos = map(np.array, zip(*results)) - self.ts += 1 - for (i, done) in enumerate(dones): - if done: - obs[i] = self.envs[i].reset() - self.ts[i] = 0 - self.actions = None - return np.array(obs), np.array(rews), np.array(dones), infos + for i in range(self.num_envs): + obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i]) + if isinstance(obs_tuple, (tuple, list)): + for t,x in enumerate(obs_tuple): + self.buf_obs[t][i] = x + else: + self.buf_obs[0][i] = obs_tuple + return self.buf_obs, self.buf_rews, self.buf_dones, self.buf_infos def reset(self): - results = [env.reset() for env in self.envs] - return np.array(results) + for i in range(self.num_envs): + obs_tuple = self.envs[i].reset() + if isinstance(obs_tuple, (tuple, list)): + for t,x in enumerate(obs_tuple): + self.buf_obs[t][i] = x + else: + self.buf_obs[0][i] = obs_tuple + return self.buf_obs def close(self): return \ No newline at end of file diff --git a/baselines/her/__init__.py b/baselines/her/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/baselines/her/actor_critic.py b/baselines/her/actor_critic.py new file mode 100644 index 0000000..d5443fe --- /dev/null +++ b/baselines/her/actor_critic.py @@ -0,0 +1,44 @@ +import tensorflow as tf +from baselines.her.util import store_args, nn + + +class ActorCritic: + @store_args + def __init__(self, inputs_tf, dimo, dimg, dimu, max_u, o_stats, g_stats, hidden, layers, + **kwargs): + """The actor-critic network and related training code. + + Args: + inputs_tf (dict of tensors): all necessary inputs for the network: the + observation (o), the goal (g), and the action (u) + dimo (int): the dimension of the observations + dimg (int): the dimension of the goals + dimu (int): the dimension of the actions + max_u (float): the maximum magnitude of actions; action outputs will be scaled + accordingly + o_stats (baselines.her.Normalizer): normalizer for observations + g_stats (baselines.her.Normalizer): normalizer for goals + hidden (int): number of hidden units that should be used in hidden layers + layers (int): number of hidden layers + """ + self.o_tf = inputs_tf['o'] + self.g_tf = inputs_tf['g'] + self.u_tf = inputs_tf['u'] + + # Prepare inputs for actor and critic. + o = self.o_stats.normalize(self.o_tf) + g = self.g_stats.normalize(self.g_tf) + input_pi = tf.concat(axis=1, values=[o, g]) # for actor + + # Networks. 
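+        # The actor is an MLP over (o, g) squashed with tanh and rescaled to [-max_u, max_u];
+        # the critic is an MLP over (o, g, u / max_u), built twice with shared weights: once on
+        # the actor's output (for policy training) and once on the replayed action u_tf (for
+        # critic training).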
+ with tf.variable_scope('pi'): + self.pi_tf = self.max_u * tf.tanh(nn( + input_pi, [self.hidden] * self.layers + [self.dimu])) + with tf.variable_scope('Q'): + # for policy training + input_Q = tf.concat(axis=1, values=[o, g, self.pi_tf / self.max_u]) + self.Q_pi_tf = nn(input_Q, [self.hidden] * self.layers + [1]) + # for critic training + input_Q = tf.concat(axis=1, values=[o, g, self.u_tf / self.max_u]) + self._input_Q = input_Q # exposed for tests + self.Q_tf = nn(input_Q, [self.hidden] * self.layers + [1], reuse=True) diff --git a/baselines/her/ddpg.py b/baselines/her/ddpg.py new file mode 100644 index 0000000..92165de --- /dev/null +++ b/baselines/her/ddpg.py @@ -0,0 +1,340 @@ +from collections import OrderedDict + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.staging import StagingArea + +from baselines import logger +from baselines.her.util import ( + import_function, store_args, flatten_grads, transitions_in_episode_batch) +from baselines.her.normalizer import Normalizer +from baselines.her.replay_buffer import ReplayBuffer +from baselines.common.mpi_adam import MpiAdam + + +def dims_to_shapes(input_dims): + return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()} + + +class DDPG(object): + @store_args + def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size, + Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T, + rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return, + sample_transitions, gamma, reuse=False, **kwargs): + """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER). + + Args: + input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the + actions (u) + buffer_size (int): number of transitions that are stored in the replay buffer + hidden (int): number of units in the hidden layers + layers (int): number of hidden layers + network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic') + polyak (float): coefficient for Polyak-averaging of the target network + batch_size (int): batch size for training + Q_lr (float): learning rate for the Q (critic) network + pi_lr (float): learning rate for the pi (actor) network + norm_eps (float): a small value used in the normalizer to avoid numerical instabilities + norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip] + max_u (float): maximum action magnitude, i.e. 
actions are in [-max_u, max_u] + action_l2 (float): coefficient for L2 penalty on the actions + clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs] + scope (str): the scope used for the TensorFlow graph + T (int): the time horizon for rollouts + rollout_batch_size (int): number of parallel rollouts per DDPG agent + subtract_goals (function): function that subtracts goals from each other + relative_goals (boolean): whether or not relative goals should be fed into the network + clip_pos_returns (boolean): whether or not positive returns should be clipped + clip_return (float): clip returns to be in [-clip_return, clip_return] + sample_transitions (function) function that samples from the replay buffer + gamma (float): gamma used for Q learning updates + reuse (boolean): whether or not the networks should be reused + """ + if self.clip_return is None: + self.clip_return = np.inf + + self.create_actor_critic = import_function(self.network_class) + + input_shapes = dims_to_shapes(self.input_dims) + self.dimo = self.input_dims['o'] + self.dimg = self.input_dims['g'] + self.dimu = self.input_dims['u'] + + # Prepare staging area for feeding data to the model. + stage_shapes = OrderedDict() + for key in sorted(self.input_dims.keys()): + if key.startswith('info_'): + continue + stage_shapes[key] = (None, *input_shapes[key]) + for key in ['o', 'g']: + stage_shapes[key + '_2'] = stage_shapes[key] + stage_shapes['r'] = (None,) + self.stage_shapes = stage_shapes + + # Create network. + with tf.variable_scope(self.scope): + self.staging_tf = StagingArea( + dtypes=[tf.float32 for _ in self.stage_shapes.keys()], + shapes=list(self.stage_shapes.values())) + self.buffer_ph_tf = [ + tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()] + self.stage_op = self.staging_tf.put(self.buffer_ph_tf) + + self._create_network(reuse=reuse) + + # Configure the replay buffer. 
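+        # Each episode contributes T transitions, but observations and achieved goals are stored
+        # with T+1 entries per episode so that o_2/ag_2 can be sliced out when sampling.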
+ buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *input_shapes[key]) + for key, val in input_shapes.items()} + buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg) + buffer_shapes['ag'] = (self.T+1, self.dimg) + + buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size + self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions) + + def _random_action(self, n): + return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu)) + + def _preprocess_og(self, o, ag, g): + if self.relative_goals: + g_shape = g.shape + g = g.reshape(-1, self.dimg) + ag = ag.reshape(-1, self.dimg) + g = self.subtract_goals(g, ag) + g = g.reshape(*g_shape) + o = np.clip(o, -self.clip_obs, self.clip_obs) + g = np.clip(g, -self.clip_obs, self.clip_obs) + return o, g + + def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False, + compute_Q=False): + o, g = self._preprocess_og(o, ag, g) + policy = self.target if use_target_net else self.main + # values to compute + vals = [policy.pi_tf] + if compute_Q: + vals += [policy.Q_pi_tf] + # feed + feed = { + policy.o_tf: o.reshape(-1, self.dimo), + policy.g_tf: g.reshape(-1, self.dimg), + policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32) + } + + ret = self.sess.run(vals, feed_dict=feed) + # action postprocessing + u = ret[0] + noise = noise_eps * self.max_u * np.random.randn(*u.shape) # gaussian noise + u += noise + u = np.clip(u, -self.max_u, self.max_u) + u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u) # eps-greedy + if u.shape[0] == 1: + u = u[0] + u = u.copy() + ret[0] = u + + if len(ret) == 1: + return ret[0] + else: + return ret + + def store_episode(self, episode_batch, update_stats=True): + """ + episode_batch: array of batch_size x (T or T+1) x dim_key + 'o' is of size T+1, others are of size T + """ + + self.buffer.store_episode(episode_batch) + + if update_stats: + # add transitions to normalizer + episode_batch['o_2'] = episode_batch['o'][:, 1:, :] + episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :] + num_normalizing_transitions = transitions_in_episode_batch(episode_batch) + transitions = self.sample_transitions(episode_batch, num_normalizing_transitions) + + o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag'] + transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g) + # No need to preprocess the o_2 and g_2 since this is only used for stats + + self.o_stats.update(transitions['o']) + self.g_stats.update(transitions['g']) + + self.o_stats.recompute_stats() + self.g_stats.recompute_stats() + + def get_current_buffer_size(self): + return self.buffer.get_current_size() + + def _sync_optimizers(self): + self.Q_adam.sync() + self.pi_adam.sync() + + def _grads(self): + # Avoid feed_dict here for performance! 
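+        # stage_batch() has already pushed the sampled batch into the StagingArea, so these
+        # fetches read the staged tensors directly instead of going through a feed_dict.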
+ critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([ + self.Q_loss_tf, + self.main.Q_pi_tf, + self.Q_grad_tf, + self.pi_grad_tf + ]) + return critic_loss, actor_loss, Q_grad, pi_grad + + def _update(self, Q_grad, pi_grad): + self.Q_adam.update(Q_grad, self.Q_lr) + self.pi_adam.update(pi_grad, self.pi_lr) + + def sample_batch(self): + transitions = self.buffer.sample(self.batch_size) + o, o_2, g = transitions['o'], transitions['o_2'], transitions['g'] + ag, ag_2 = transitions['ag'], transitions['ag_2'] + transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g) + transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g) + + transitions_batch = [transitions[key] for key in self.stage_shapes.keys()] + return transitions_batch + + def stage_batch(self, batch=None): + if batch is None: + batch = self.sample_batch() + assert len(self.buffer_ph_tf) == len(batch) + self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch))) + + def train(self, stage=True): + if stage: + self.stage_batch() + critic_loss, actor_loss, Q_grad, pi_grad = self._grads() + self._update(Q_grad, pi_grad) + return critic_loss, actor_loss + + def _init_target_net(self): + self.sess.run(self.init_target_net_op) + + def update_target_net(self): + self.sess.run(self.update_target_net_op) + + def clear_buffer(self): + self.buffer.clear_buffer() + + def _vars(self, scope): + res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope) + assert len(res) > 0 + return res + + def _global_vars(self, scope): + res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope) + return res + + def _create_network(self, reuse=False): + logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u)) + + self.sess = tf.get_default_session() + if self.sess is None: + self.sess = tf.InteractiveSession() + + # running averages + with tf.variable_scope('o_stats') as vs: + if reuse: + vs.reuse_variables() + self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess) + with tf.variable_scope('g_stats') as vs: + if reuse: + vs.reuse_variables() + self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess) + + # mini-batch sampling. + batch = self.staging_tf.get() + batch_tf = OrderedDict([(key, batch[i]) + for i, key in enumerate(self.stage_shapes.keys())]) + batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1]) + + # networks + with tf.variable_scope('main') as vs: + if reuse: + vs.reuse_variables() + self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__) + vs.reuse_variables() + with tf.variable_scope('target') as vs: + if reuse: + vs.reuse_variables() + target_batch_tf = batch_tf.copy() + target_batch_tf['o'] = batch_tf['o_2'] + target_batch_tf['g'] = batch_tf['g_2'] + self.target = self.create_actor_critic( + target_batch_tf, net_type='target', **self.__dict__) + vs.reuse_variables() + assert len(self._vars("main")) == len(self._vars("target")) + + # loss functions + target_Q_pi_tf = self.target.Q_pi_tf + clip_range = (-self.clip_return, 0. 
if self.clip_pos_returns else np.inf) + target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range) + self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf)) + self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf) + self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u)) + Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q')) + pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi')) + assert len(self._vars('main/Q')) == len(Q_grads_tf) + assert len(self._vars('main/pi')) == len(pi_grads_tf) + self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q')) + self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi')) + self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q')) + self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi')) + + # optimizers + self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False) + self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False) + + # polyak averaging + self.main_vars = self._vars('main/Q') + self._vars('main/pi') + self.target_vars = self._vars('target/Q') + self._vars('target/pi') + self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats') + self.init_target_net_op = list( + map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars))) + self.update_target_net_op = list( + map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars))) + + # initialize all variables + tf.variables_initializer(self._global_vars('')).run() + self._sync_optimizers() + self._init_target_net() + + def logs(self, prefix=''): + logs = [] + logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))] + logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))] + logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))] + logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))] + + if prefix is not '' and not prefix.endswith('/'): + return [(prefix + '/' + key, val) for key, val in logs] + else: + return logs + + def __getstate__(self): + """Our policies can be loaded from pkl, but after unpickling you cannot continue training. + """ + excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', + 'main', 'target', 'lock', 'env', 'sample_transitions', + 'stage_shapes', 'create_actor_critic'] + + state = {k: v for k, v in self.__dict__.items() if all([not subname in k for subname in excluded_subnames])} + state['buffer_size'] = self.buffer_size + state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name]) + return state + + def __setstate__(self, state): + if 'sample_transitions' not in state: + # We don't need this for playing the policy. 
+ state['sample_transitions'] = None + + self.__init__(**state) + # set up stats (they are overwritten in __init__) + for k, v in state.items(): + if k[-6:] == '_stats': + self.__dict__[k] = v + # load TF variables + vars = [x for x in self._global_vars('') if 'buffer' not in x.name] + assert(len(vars) == len(state["tf"])) + node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])] + self.sess.run(node) diff --git a/baselines/her/experiment/__init__.py b/baselines/her/experiment/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/baselines/her/experiment/config.py b/baselines/her/experiment/config.py new file mode 100644 index 0000000..d64211b --- /dev/null +++ b/baselines/her/experiment/config.py @@ -0,0 +1,170 @@ +from copy import deepcopy +import numpy as np +import json +import os +import gym + +from baselines import logger +from baselines.her.ddpg import DDPG +from baselines.her.her import make_sample_her_transitions + + +DEFAULT_ENV_PARAMS = { + 'FetchReach-v0': { + 'n_cycles': 10, + }, +} + + +DEFAULT_PARAMS = { + # env + 'max_u': 1., # max absolute value of actions on different coordinates + # ddpg + 'layers': 3, # number of layers in the critic/actor networks + 'hidden': 256, # number of neurons in each hidden layers + 'network_class': 'baselines.her.actor_critic:ActorCritic', + 'Q_lr': 0.001, # critic learning rate + 'pi_lr': 0.001, # actor learning rate + 'buffer_size': int(1E6), # for experience replay + 'polyak': 0.95, # polyak averaging coefficient + 'action_l2': 1.0, # quadratic penalty on actions (before rescaling by max_u) + 'clip_obs': 200., + 'scope': 'ddpg', # can be tweaked for testing + 'relative_goals': False, + # training + 'n_cycles': 50, # per epoch + 'rollout_batch_size': 2, # per mpi thread + 'n_batches': 40, # training batches per cycle + 'batch_size': 256, # per mpi thread, measured in transitions and reduced to even multiple of chunk_length. + 'n_test_rollouts': 10, # number of test rollouts per epoch, each consists of rollout_batch_size rollouts + 'test_with_polyak': False, # run test episodes with the target network + # exploration + 'random_eps': 0.3, # percentage of time a random action is taken + 'noise_eps': 0.2, # std of gaussian noise added to not-completely-random actions as a percentage of max_u + # HER + 'replay_strategy': 'future', # supported modes: future, none + 'replay_k': 4, # number of additional goals used for replay, only used if off_policy_data=future + # normalization + 'norm_eps': 0.01, # epsilon used for observation normalization + 'norm_clip': 5, # normalized observations are cropped to this values +} + + +CACHED_ENVS = {} +def cached_make_env(make_env): + """ + Only creates a new environment from the provided function if one has not yet already been + created. This is useful here because we need to infer certain properties of the env, e.g. + its observation and action spaces, without any intend of actually using it. + """ + if make_env not in CACHED_ENVS: + env = make_env() + CACHED_ENVS[make_env] = env + return CACHED_ENVS[make_env] + + +def prepare_params(kwargs): + # DDPG params + ddpg_params = dict() + + env_name = kwargs['env_name'] + def make_env(): + return gym.make(env_name) + kwargs['make_env'] = make_env + tmp_env = cached_make_env(kwargs['make_env']) + assert hasattr(tmp_env, '_max_episode_steps') + kwargs['T'] = tmp_env._max_episode_steps + tmp_env.reset() + kwargs['max_u'] = np.array(kwargs['max_u']) if type(kwargs['max_u']) == list else kwargs['max_u'] + kwargs['gamma'] = 1. - 1. 
/ kwargs['T'] + if 'lr' in kwargs: + kwargs['pi_lr'] = kwargs['lr'] + kwargs['Q_lr'] = kwargs['lr'] + del kwargs['lr'] + for name in ['buffer_size', 'hidden', 'layers', + 'network_class', + 'polyak', + 'batch_size', 'Q_lr', 'pi_lr', + 'norm_eps', 'norm_clip', 'max_u', + 'action_l2', 'clip_obs', 'scope', 'relative_goals']: + ddpg_params[name] = kwargs[name] + kwargs['_' + name] = kwargs[name] + del kwargs[name] + kwargs['ddpg_params'] = ddpg_params + + return kwargs + + +def log_params(params, logger=logger): + for key in sorted(params.keys()): + logger.info('{}: {}'.format(key, params[key])) + + +def configure_her(params): + env = cached_make_env(params['make_env']) + env.reset() + def reward_fun(ag_2, g, info): # vectorized + return env.compute_reward(achieved_goal=ag_2, desired_goal=g, info=info) + + # Prepare configuration for HER. + her_params = { + 'reward_fun': reward_fun, + } + for name in ['replay_strategy', 'replay_k']: + her_params[name] = params[name] + params['_' + name] = her_params[name] + del params[name] + sample_her_transitions = make_sample_her_transitions(**her_params) + + return sample_her_transitions + + +def simple_goal_subtract(a, b): + assert a.shape == b.shape + return a - b + + +def configure_ddpg(dims, params, reuse=False, use_mpi=True, clip_return=True): + sample_her_transitions = configure_her(params) + # Extract relevant parameters. + gamma = params['gamma'] + rollout_batch_size = params['rollout_batch_size'] + ddpg_params = params['ddpg_params'] + + input_dims = dims.copy() + + # DDPG agent + env = cached_make_env(params['make_env']) + env.reset() + ddpg_params.update({'input_dims': input_dims, # agent takes an input observations + 'T': params['T'], + 'clip_pos_returns': True, # clip positive returns + 'clip_return': (1. / (1. - gamma)) if clip_return else np.inf, # max abs of return + 'rollout_batch_size': rollout_batch_size, + 'subtract_goals': simple_goal_subtract, + 'sample_transitions': sample_her_transitions, + 'gamma': gamma, + }) + ddpg_params['info'] = { + 'env_name': params['env_name'], + } + policy = DDPG(reuse=reuse, **ddpg_params, use_mpi=use_mpi) + return policy + + +def configure_dims(params): + env = cached_make_env(params['make_env']) + env.reset() + obs, _, _, info = env.step(env.action_space.sample()) + + dims = { + 'o': obs['observation'].shape[0], + 'u': env.action_space.shape[0], + 'g': obs['desired_goal'].shape[0], + } + for key, value in info.items(): + value = np.array(value) + if value.ndim == 0: + value = value.reshape(1) + dims['info_{}'.format(key)] = value.shape[0] + return dims diff --git a/baselines/her/experiment/play.py b/baselines/her/experiment/play.py new file mode 100644 index 0000000..5b2f85d --- /dev/null +++ b/baselines/her/experiment/play.py @@ -0,0 +1,60 @@ +import click +import numpy as np +import pickle + +from baselines import logger +from baselines.common import set_global_seeds +import baselines.her.experiment.config as config +from baselines.her.rollout import RolloutWorker + + +@click.command() +@click.argument('policy_file', type=str) +@click.option('--seed', type=int, default=0) +@click.option('--n_test_rollouts', type=int, default=10) +@click.option('--render', type=int, default=1) +def main(policy_file, seed, n_test_rollouts, render): + set_global_seeds(seed) + + # Load policy. + with open(policy_file, 'rb') as f: + policy = pickle.load(f) + env_name = policy.info['env_name'] + + # Prepare params. 
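+    # Rebuild the default configuration for this environment so that dims, T, gamma and the
+    # exploration settings can be passed to the evaluation RolloutWorker below.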
+ params = config.DEFAULT_PARAMS + if env_name in config.DEFAULT_ENV_PARAMS: + params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in + params['env_name'] = env_name + params = config.prepare_params(params) + config.log_params(params, logger=logger) + + dims = config.configure_dims(params) + + eval_params = { + 'exploit': True, + 'use_target_net': params['test_with_polyak'], + 'compute_Q': True, + 'rollout_batch_size': 1, + 'render': bool(render), + } + + for name in ['T', 'gamma', 'noise_eps', 'random_eps']: + eval_params[name] = params[name] + + evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params) + evaluator.seed(seed) + + # Run evaluation. + evaluator.clear_history() + for _ in range(n_test_rollouts): + evaluator.generate_rollouts() + + # record logs + for key, val in evaluator.logs('test'): + logger.record_tabular(key, np.mean(val)) + logger.dump_tabular() + + +if __name__ == '__main__': + main() diff --git a/baselines/her/experiment/plot.py b/baselines/her/experiment/plot.py new file mode 100644 index 0000000..694db4b --- /dev/null +++ b/baselines/her/experiment/plot.py @@ -0,0 +1,118 @@ +import os +import matplotlib.pyplot as plt +import numpy as np +import json +import seaborn as sns; sns.set() +import glob2 +import argparse + + +def smooth_reward_curve(x, y): + halfwidth = int(np.ceil(len(x) / 60)) # Halfwidth of our smoothing convolution + k = halfwidth + xsmoo = x + ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='same') / np.convolve(np.ones_like(y), np.ones(2 * k + 1), + mode='same') + return xsmoo, ysmoo + + +def load_results(file): + if not os.path.exists(file): + return None + with open(file, 'r') as f: + lines = [line for line in f] + if len(lines) < 2: + return None + keys = [name.strip() for name in lines[0].split(',')] + data = np.genfromtxt(file, delimiter=',', skip_header=1, filling_values=0.) + if data.ndim == 1: + data = data.reshape(1, -1) + assert data.ndim == 2 + assert data.shape[-1] == len(keys) + result = {} + for idx, key in enumerate(keys): + result[key] = data[:, idx] + return result + + +def pad(xs, value=np.nan): + maxlen = np.max([len(x) for x in xs]) + + padded_xs = [] + for x in xs: + if x.shape[0] >= maxlen: + padded_xs.append(x) + + padding = np.ones((maxlen - x.shape[0],) + x.shape[1:]) * value + x_padded = np.concatenate([x, padding], axis=0) + assert x_padded.shape[1:] == x.shape[1:] + assert x_padded.shape[0] == maxlen + padded_xs.append(x_padded) + return np.array(padded_xs) + + +parser = argparse.ArgumentParser() +parser.add_argument('dir', type=str) +parser.add_argument('--smooth', type=int, default=1) +args = parser.parse_args() + +# Load all data. 
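+# Each run directory is expected to contain a progress.csv (with 'epoch' and 'test/success_rate'
+# columns) and a metadata.json naming the env and replay strategy; runs are grouped by env/config.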
+data = {} +paths = [os.path.abspath(os.path.join(path, '..')) for path in glob2.glob(os.path.join(args.dir, '**', 'progress.csv'))] +for curr_path in paths: + if not os.path.isdir(curr_path): + continue + results = load_results(os.path.join(curr_path, 'progress.csv')) + if not results: + print('skipping {}'.format(curr_path)) + continue + print('loading {} ({})'.format(curr_path, len(results['epoch']))) + with open(os.path.join(curr_path, 'metadata.json'), 'r') as f: + metadata = json.load(f) + + success_rate = np.array(results['test/success_rate']) + epoch = np.array(results['epoch']) + 1 + env_id = metadata['kwargs']['env_name'] + replay_strategy = metadata['kwargs']['replay_strategy'] + + if replay_strategy == 'future': + config = 'her' + else: + config = 'ddpg' + if 'Dense' in env_id: + config += '-dense' + else: + config += '-sparse' + env_id = env_id.replace('Dense', '') + + # Process and smooth data. + assert success_rate.shape == epoch.shape + x = epoch + y = success_rate + if args.smooth: + x, y = smooth_reward_curve(epoch, success_rate) + assert x.shape == y.shape + + if env_id not in data: + data[env_id] = {} + if config not in data[env_id]: + data[env_id][config] = [] + data[env_id][config].append((x, y)) + +# Plot data. +for env_id in sorted(data.keys()): + print('exporting {}'.format(env_id)) + plt.clf() + + for config in sorted(data[env_id].keys()): + xs, ys = zip(*data[env_id][config]) + xs, ys = pad(xs), pad(ys) + assert xs.shape == ys.shape + + plt.plot(xs[0], np.nanmedian(ys, axis=0), label=config) + plt.fill_between(xs[0], np.nanpercentile(ys, 25, axis=0), np.nanpercentile(ys, 75, axis=0), alpha=0.25) + plt.title(env_id) + plt.xlabel('Epoch') + plt.ylabel('Median Success Rate') + plt.legend() + plt.savefig(os.path.join(args.dir, 'fig_{}.png'.format(env_id))) diff --git a/baselines/her/experiment/train.py b/baselines/her/experiment/train.py new file mode 100644 index 0000000..c8d2405 --- /dev/null +++ b/baselines/her/experiment/train.py @@ -0,0 +1,169 @@ +import os +import sys + +import click +import numpy as np +import json +from mpi4py import MPI + +from baselines import logger +from baselines.common import set_global_seeds +from baselines.common.mpi_moments import mpi_moments +import baselines.her.experiment.config as config +from baselines.her.rollout import RolloutWorker +from baselines.her.util import mpi_fork + + +def mpi_average(value): + if value == []: + value = [0.] 
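+    # mpi_moments reduces across all MPI workers; only the mean (index 0) is kept.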
+ if not isinstance(value, list): + value = [value] + return mpi_moments(np.array(value))[0] + + +def train(policy, rollout_worker, evaluator, + n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval, + save_policies, **kwargs): + rank = MPI.COMM_WORLD.Get_rank() + + latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl') + best_policy_path = os.path.join(logger.get_dir(), 'policy_best.pkl') + periodic_policy_path = os.path.join(logger.get_dir(), 'policy_{}.pkl') + + logger.info("Training...") + best_success_rate = -1 + for epoch in range(n_epochs): + # train + rollout_worker.clear_history() + for _ in range(n_cycles): + episode = rollout_worker.generate_rollouts() + policy.store_episode(episode) + for _ in range(n_batches): + policy.train() + policy.update_target_net() + + # test + evaluator.clear_history() + for _ in range(n_test_rollouts): + evaluator.generate_rollouts() + + # record logs + logger.record_tabular('epoch', epoch) + for key, val in evaluator.logs('test'): + logger.record_tabular(key, mpi_average(val)) + for key, val in rollout_worker.logs('train'): + logger.record_tabular(key, mpi_average(val)) + for key, val in policy.logs(): + logger.record_tabular(key, mpi_average(val)) + + if rank == 0: + logger.dump_tabular() + + # save the policy if it's better than the previous ones + success_rate = mpi_average(evaluator.current_success_rate()) + if rank == 0 and success_rate >= best_success_rate and save_policies: + best_success_rate = success_rate + logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path)) + evaluator.save_policy(best_policy_path) + evaluator.save_policy(latest_policy_path) + if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_policies: + policy_path = periodic_policy_path.format(epoch) + logger.info('Saving periodic policy to {} ...'.format(policy_path)) + evaluator.save_policy(policy_path) + + # make sure that different threads have different seeds + local_uniform = np.random.uniform(size=(1,)) + root_uniform = local_uniform.copy() + MPI.COMM_WORLD.Bcast(root_uniform, root=0) + if rank != 0: + assert local_uniform[0] != root_uniform[0] + + +def launch( + env_name, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return, + override_params={}, save_policies=True +): + # Fork for multi-CPU MPI implementation. + if num_cpu > 1: + whoami = mpi_fork(num_cpu) + if whoami == 'parent': + sys.exit(0) + import baselines.common.tf_util as U + U.single_threaded_session().__enter__() + rank = MPI.COMM_WORLD.Get_rank() + + # Configure logging + if rank == 0 and (logdir or logger.get_dir() is None): + logger.configure(dir=logdir) + logdir = logger.get_dir() + os.makedirs(logdir, exist_ok=True) + assert logger.get_dir() is not None + + # Seed everything. + rank_seed = seed + 1000000 * rank + set_global_seeds(rank_seed) + + # Prepare params. 
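+    # Start from the library defaults, merge env-specific settings and any override_params, and
+    # dump the merged dict to params.json before prepare_params() transforms it in place.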
+ params = config.DEFAULT_PARAMS + params['env_name'] = env_name + params['replay_strategy'] = replay_strategy + if env_name in config.DEFAULT_ENV_PARAMS: + params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in + params.update(**override_params) # makes it possible to override any parameter + with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f: + json.dump(params, f) + params = config.prepare_params(params) + config.log_params(params, logger=logger) + + dims = config.configure_dims(params) + policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return) + + rollout_params = { + 'exploit': False, + 'use_target_net': False, + 'use_demo_states': True, + 'compute_Q': False, + 'T': params['T'], + } + + eval_params = { + 'exploit': True, + 'use_target_net': params['test_with_polyak'], + 'use_demo_states': False, + 'compute_Q': True, + 'T': params['T'], + } + + for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']: + rollout_params[name] = params[name] + eval_params[name] = params[name] + + rollout_worker = RolloutWorker(params['make_env'], policy, dims, logger, **rollout_params) + rollout_worker.seed(rank_seed) + + evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params) + evaluator.seed(rank_seed) + + train( + logdir=logdir, policy=policy, rollout_worker=rollout_worker, + evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'], + n_cycles=params['n_cycles'], n_batches=params['n_batches'], + policy_save_interval=policy_save_interval, save_policies=save_policies) + + +@click.command() +@click.option('--env_name', type=str, default='FetchReach-v0') +@click.option('--logdir', type=str, default=None) +@click.option('--n_epochs', type=int, default=50) +@click.option('--num_cpu', type=int, default=1) +@click.option('--seed', type=int, default=0) +@click.option('--policy_save_interval', type=int, default=5) +@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future') +@click.option('--clip_return', type=int, default=1) +def main(**kwargs): + launch(**kwargs) + + +if __name__ == '__main__': + main() diff --git a/baselines/her/her.py b/baselines/her/her.py new file mode 100644 index 0000000..76f3c34 --- /dev/null +++ b/baselines/her/her.py @@ -0,0 +1,63 @@ +import numpy as np + + +def make_sample_her_transitions(replay_strategy, replay_k, reward_fun): + """Creates a sample function that can be used for HER experience replay. + + Args: + replay_strategy (in ['future', 'none']): the HER replay strategy; if set to 'none', + regular DDPG experience replay is used + replay_k (int): the ratio between HER replays and regular replays (e.g. k = 4 -> 4 times + as many HER replays as regular replays are used) + reward_fun (function): function to re-compute the reward with substituted goals + """ + if replay_strategy == 'future': + future_p = 1 - (1. / (1 + replay_k)) + else: # 'replay_strategy' == 'none' + future_p = 0 + + def _sample_her_transitions(episode_batch, batch_size_in_transitions): + """episode_batch is {key: array(buffer_size x T x dim_key)} + """ + T = episode_batch['u'].shape[1] + rollout_batch_size = episode_batch['u'].shape[0] + batch_size = batch_size_in_transitions + + # Select which episodes and time steps to use. 
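+        # A fraction future_p = 1 - 1/(1 + replay_k) of the sampled transitions (0.8 for the
+        # default replay_k=4) will have their goal replaced below with an achieved goal from a
+        # later timestep of the same episode; the remaining transitions keep their original goal.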
+ episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) + t_samples = np.random.randint(T, size=batch_size) + transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() + for key in episode_batch.keys()} + + # Select future time indexes proportional with probability future_p. These + # will be used for HER replay by substituting in future goals. + her_indexes = np.where(np.random.uniform(size=batch_size) < future_p) + future_offset = np.random.uniform(size=batch_size) * (T - t_samples) + future_offset = future_offset.astype(int) + future_t = (t_samples + 1 + future_offset)[her_indexes] + + # Replace goal with achieved goal but only for the previously-selected + # HER transitions (as defined by her_indexes). For the other transitions, + # keep the original goal. + future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t] + transitions['g'][her_indexes] = future_ag + + # Reconstruct info dictionary for reward computation. + info = {} + for key, value in transitions.items(): + if key.startswith('info_'): + info[key.replace('info_', '')] = value + + # Re-compute reward since we may have substituted the goal. + reward_params = {k: transitions[k] for k in ['ag_2', 'g']} + reward_params['info'] = info + transitions['r'] = reward_fun(**reward_params) + + transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) + for k in transitions.keys()} + + assert(transitions['u'].shape[0] == batch_size_in_transitions) + + return transitions + + return _sample_her_transitions diff --git a/baselines/her/normalizer.py b/baselines/her/normalizer.py new file mode 100644 index 0000000..d2b0588 --- /dev/null +++ b/baselines/her/normalizer.py @@ -0,0 +1,140 @@ +import threading + +import numpy as np +from mpi4py import MPI +import tensorflow as tf + +from baselines.her.util import reshape_for_broadcasting + + +class Normalizer: + def __init__(self, size, eps=1e-2, default_clip_range=np.inf, sess=None): + """A normalizer that ensures that observations are approximately distributed according to + a standard Normal distribution (i.e. have mean zero and variance one). 
+ + Args: + size (int): the size of the observation to be normalized + eps (float): a small constant that avoids underflows + default_clip_range (float): normalized observations are clipped to be in + [-default_clip_range, default_clip_range] + sess (object): the TensorFlow session to be used + """ + self.size = size + self.eps = eps + self.default_clip_range = default_clip_range + self.sess = sess if sess is not None else tf.get_default_session() + + self.local_sum = np.zeros(self.size, np.float32) + self.local_sumsq = np.zeros(self.size, np.float32) + self.local_count = np.zeros(1, np.float32) + + self.sum_tf = tf.get_variable( + initializer=tf.zeros_initializer(), shape=self.local_sum.shape, name='sum', + trainable=False, dtype=tf.float32) + self.sumsq_tf = tf.get_variable( + initializer=tf.zeros_initializer(), shape=self.local_sumsq.shape, name='sumsq', + trainable=False, dtype=tf.float32) + self.count_tf = tf.get_variable( + initializer=tf.ones_initializer(), shape=self.local_count.shape, name='count', + trainable=False, dtype=tf.float32) + self.mean = tf.get_variable( + initializer=tf.zeros_initializer(), shape=(self.size,), name='mean', + trainable=False, dtype=tf.float32) + self.std = tf.get_variable( + initializer=tf.ones_initializer(), shape=(self.size,), name='std', + trainable=False, dtype=tf.float32) + self.count_pl = tf.placeholder(name='count_pl', shape=(1,), dtype=tf.float32) + self.sum_pl = tf.placeholder(name='sum_pl', shape=(self.size,), dtype=tf.float32) + self.sumsq_pl = tf.placeholder(name='sumsq_pl', shape=(self.size,), dtype=tf.float32) + + self.update_op = tf.group( + self.count_tf.assign_add(self.count_pl), + self.sum_tf.assign_add(self.sum_pl), + self.sumsq_tf.assign_add(self.sumsq_pl) + ) + self.recompute_op = tf.group( + tf.assign(self.mean, self.sum_tf / self.count_tf), + tf.assign(self.std, tf.sqrt(tf.maximum( + tf.square(self.eps), + self.sumsq_tf / self.count_tf - tf.square(self.sum_tf / self.count_tf) + ))), + ) + self.lock = threading.Lock() + + def update(self, v): + v = v.reshape(-1, self.size) + + with self.lock: + self.local_sum += v.sum(axis=0) + self.local_sumsq += (np.square(v)).sum(axis=0) + self.local_count[0] += v.shape[0] + + def normalize(self, v, clip_range=None): + if clip_range is None: + clip_range = self.default_clip_range + mean = reshape_for_broadcasting(self.mean, v) + std = reshape_for_broadcasting(self.std, v) + return tf.clip_by_value((v - mean) / std, -clip_range, clip_range) + + def denormalize(self, v): + mean = reshape_for_broadcasting(self.mean, v) + std = reshape_for_broadcasting(self.std, v) + return mean + v * std + + def _mpi_average(self, x): + buf = np.zeros_like(x) + MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM) + buf /= MPI.COMM_WORLD.Get_size() + return buf + + def synchronize(self, local_sum, local_sumsq, local_count, root=None): + local_sum[...] = self._mpi_average(local_sum) + local_sumsq[...] = self._mpi_average(local_sumsq) + local_count[...] = self._mpi_average(local_count) + return local_sum, local_sumsq, local_count + + def recompute_stats(self): + with self.lock: + # Copy over results. + local_count = self.local_count.copy() + local_sum = self.local_sum.copy() + local_sumsq = self.local_sumsq.copy() + + # Reset. + self.local_count[...] = 0 + self.local_sum[...] = 0 + self.local_sumsq[...] = 0 + + # We perform the synchronization outside of the lock to keep the critical section as short + # as possible. 
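+        # The MPI-averaged totals are added to the TF accumulators, then the stats are recomputed
+        # as mean = sum / count and std = sqrt(max(eps^2, sumsq / count - (sum / count)^2)).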
+ synced_sum, synced_sumsq, synced_count = self.synchronize( + local_sum=local_sum, local_sumsq=local_sumsq, local_count=local_count) + + self.sess.run(self.update_op, feed_dict={ + self.count_pl: synced_count, + self.sum_pl: synced_sum, + self.sumsq_pl: synced_sumsq, + }) + self.sess.run(self.recompute_op) + + +class IdentityNormalizer: + def __init__(self, size, std=1.): + self.size = size + self.mean = tf.zeros(self.size, tf.float32) + self.std = std * tf.ones(self.size, tf.float32) + + def update(self, x): + pass + + def normalize(self, x, clip_range=None): + return x / self.std + + def denormalize(self, x): + return self.std * x + + def synchronize(self): + pass + + def recompute_stats(self): + pass diff --git a/baselines/her/replay_buffer.py b/baselines/her/replay_buffer.py new file mode 100644 index 0000000..b000552 --- /dev/null +++ b/baselines/her/replay_buffer.py @@ -0,0 +1,108 @@ +import threading + +import numpy as np + + +class ReplayBuffer: + def __init__(self, buffer_shapes, size_in_transitions, T, sample_transitions): + """Creates a replay buffer. + + Args: + buffer_shapes (dict of ints): the shape for all buffers that are used in the replay + buffer + size_in_transitions (int): the size of the buffer, measured in transitions + T (int): the time horizon for episodes + sample_transitions (function): a function that samples from the replay buffer + """ + self.buffer_shapes = buffer_shapes + self.size = size_in_transitions // T + self.T = T + self.sample_transitions = sample_transitions + + # self.buffers is {key: array(size_in_episodes x T or T+1 x dim_key)} + self.buffers = {key: np.empty([self.size, *shape]) + for key, shape in buffer_shapes.items()} + + # memory management + self.current_size = 0 + self.n_transitions_stored = 0 + + self.lock = threading.Lock() + + @property + def full(self): + with self.lock: + return self.current_size == self.size + + def sample(self, batch_size): + """Returns a dict {key: array(batch_size x shapes[key])} + """ + buffers = {} + + with self.lock: + assert self.current_size > 0 + for key in self.buffers.keys(): + buffers[key] = self.buffers[key][:self.current_size] + + buffers['o_2'] = buffers['o'][:, 1:, :] + buffers['ag_2'] = buffers['ag'][:, 1:, :] + + transitions = self.sample_transitions(buffers, batch_size) + + for key in (['r', 'o_2', 'ag_2'] + list(self.buffers.keys())): + assert key in transitions, "key %s missing from transitions" % key + + return transitions + + def store_episode(self, episode_batch): + """episode_batch: array(batch_size x (T or T+1) x dim_key) + """ + batch_sizes = [len(episode_batch[key]) for key in episode_batch.keys()] + assert np.all(np.array(batch_sizes) == batch_sizes[0]) + batch_size = batch_sizes[0] + + with self.lock: + idxs = self._get_storage_idx(batch_size) + + # load inputs into buffers + for key in self.buffers.keys(): + self.buffers[key][idxs] = episode_batch[key] + + self.n_transitions_stored += batch_size * self.T + + def get_current_episode_size(self): + with self.lock: + return self.current_size + + def get_current_size(self): + with self.lock: + return self.current_size * self.T + + def get_transitions_stored(self): + with self.lock: + return self.n_transitions_stored + + def clear_buffer(self): + with self.lock: + self.current_size = 0 + + def _get_storage_idx(self, inc=None): + inc = inc or 1 # size increment + assert inc <= self.size, "Batch committed to replay is too large!" + # go consecutively until you hit the end, and then go randomly. 
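+        # Three cases: the whole batch fits in the free tail (consecutive slots); it only partly
+        # fits (fill the tail, then overwrite random old episodes with the overflow); or the
+        # buffer is already full (overwrite `inc` random episodes).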
+ if self.current_size+inc <= self.size: + idx = np.arange(self.current_size, self.current_size+inc) + elif self.current_size < self.size: + overflow = inc - (self.size - self.current_size) + idx_a = np.arange(self.current_size, self.size) + idx_b = np.random.randint(0, self.current_size, overflow) + idx = np.concatenate([idx_a, idx_b]) + else: + idx = np.random.randint(0, self.size, inc) + + # update replay size + self.current_size = min(self.size, self.current_size+inc) + + if inc == 1: + idx = idx[0] + return idx diff --git a/baselines/her/rollout.py b/baselines/her/rollout.py new file mode 100644 index 0000000..5beba69 --- /dev/null +++ b/baselines/her/rollout.py @@ -0,0 +1,188 @@ +from collections import deque + +import numpy as np +import pickle +from mujoco_py import MujocoException + +from baselines.her.util import convert_episode_to_batch_major, store_args + + +class RolloutWorker: + + @store_args + def __init__(self, make_env, policy, dims, logger, T, rollout_batch_size=1, + exploit=False, use_target_net=False, compute_Q=False, noise_eps=0, + random_eps=0, history_len=100, render=False, **kwargs): + """Rollout worker generates experience by interacting with one or many environments. + + Args: + make_env (function): a factory function that creates a new instance of the environment + when called + policy (object): the policy that is used to act + dims (dict of ints): the dimensions for observations (o), goals (g), and actions (u) + logger (object): the logger that is used by the rollout worker + rollout_batch_size (int): the number of parallel rollouts that should be used + exploit (boolean): whether or not to exploit, i.e. to act optimally according to the + current policy without any exploration + use_target_net (boolean): whether or not to use the target net for rollouts + compute_Q (boolean): whether or not to compute the Q values alongside the actions + noise_eps (float): scale of the additive Gaussian noise + random_eps (float): probability of selecting a completely random action + history_len (int): length of history for statistics smoothing + render (boolean): whether or not to render the rollouts + """ + self.envs = [make_env() for _ in range(rollout_batch_size)] + assert self.T > 0 + + self.info_keys = [key.replace('info_', '') for key in dims.keys() if key.startswith('info_')] + + self.success_history = deque(maxlen=history_len) + self.Q_history = deque(maxlen=history_len) + + self.n_episodes = 0 + self.g = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # goals + self.initial_o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32) # observations + self.initial_ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # achieved goals + self.reset_all_rollouts() + self.clear_history() + + def reset_rollout(self, i): + """Resets the `i`-th rollout environment, re-samples a new goal, and updates the `initial_o` + and `g` arrays accordingly. + """ + obs = self.envs[i].reset() + self.initial_o[i] = obs['observation'] + self.initial_ag[i] = obs['achieved_goal'] + self.g[i] = obs['desired_goal'] + + def reset_all_rollouts(self): + """Resets all `rollout_batch_size` rollout workers. + """ + for i in range(self.rollout_batch_size): + self.reset_rollout(i) + + def generate_rollouts(self): + """Performs `rollout_batch_size` rollouts in parallel for time horizon `T` with the current + policy acting on it accordingly. 
+ """ + self.reset_all_rollouts() + + # compute observations + o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32) # observations + ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # achieved goals + o[:] = self.initial_o + ag[:] = self.initial_ag + + # generate episodes + obs, achieved_goals, acts, goals, successes = [], [], [], [], [] + info_values = [np.empty((self.T, self.rollout_batch_size, self.dims['info_' + key]), np.float32) for key in self.info_keys] + Qs = [] + for t in range(self.T): + policy_output = self.policy.get_actions( + o, ag, self.g, + compute_Q=self.compute_Q, + noise_eps=self.noise_eps if not self.exploit else 0., + random_eps=self.random_eps if not self.exploit else 0., + use_target_net=self.use_target_net) + + if self.compute_Q: + u, Q = policy_output + Qs.append(Q) + else: + u = policy_output + + if u.ndim == 1: + # The non-batched case should still have a reasonable shape. + u = u.reshape(1, -1) + + o_new = np.empty((self.rollout_batch_size, self.dims['o'])) + ag_new = np.empty((self.rollout_batch_size, self.dims['g'])) + success = np.zeros(self.rollout_batch_size) + # compute new states and observations + for i in range(self.rollout_batch_size): + try: + # We fully ignore the reward here because it will have to be re-computed + # for HER. + curr_o_new, _, _, info = self.envs[i].step(u[i]) + if 'is_success' in info: + success[i] = info['is_success'] + o_new[i] = curr_o_new['observation'] + ag_new[i] = curr_o_new['achieved_goal'] + for idx, key in enumerate(self.info_keys): + info_values[idx][t, i] = info[key] + if self.render: + self.envs[i].render() + except MujocoException as e: + return self.generate_rollouts() + + if np.isnan(o_new).any(): + self.logger.warning('NaN caught during rollout generation. Trying again...') + self.reset_all_rollouts() + return self.generate_rollouts() + + obs.append(o.copy()) + achieved_goals.append(ag.copy()) + successes.append(success.copy()) + acts.append(u.copy()) + goals.append(self.g.copy()) + o[...] = o_new + ag[...] = ag_new + obs.append(o.copy()) + achieved_goals.append(ag.copy()) + self.initial_o[:] = o + + episode = dict(o=obs, + u=acts, + g=goals, + ag=achieved_goals) + for key, value in zip(self.info_keys, info_values): + episode['info_{}'.format(key)] = value + + # stats + successful = np.array(successes)[-1, :] + assert successful.shape == (self.rollout_batch_size,) + success_rate = np.mean(successful) + self.success_history.append(success_rate) + if self.compute_Q: + self.Q_history.append(np.mean(Qs)) + self.n_episodes += self.rollout_batch_size + + return convert_episode_to_batch_major(episode) + + def clear_history(self): + """Clears all histories that are used for statistics + """ + self.success_history.clear() + self.Q_history.clear() + + def current_success_rate(self): + return np.mean(self.success_history) + + def current_mean_Q(self): + return np.mean(self.Q_history) + + def save_policy(self, path): + """Pickles the current policy for later inspection. + """ + with open(path, 'wb') as f: + pickle.dump(self.policy, f) + + def logs(self, prefix='worker'): + """Generates a dictionary that contains all collected statistics. 
+        """
+        logs = []
+        logs += [('success_rate', np.mean(self.success_history))]
+        if self.compute_Q:
+            logs += [('mean_Q', np.mean(self.Q_history))]
+        logs += [('episode', self.n_episodes)]
+
+        if prefix != '' and not prefix.endswith('/'):
+            return [(prefix + '/' + key, val) for key, val in logs]
+        else:
+            return logs
+
+    def seed(self, seed):
+        """Seeds each environment with a distinct seed derived from the passed-in global seed.
+        """
+        for idx, env in enumerate(self.envs):
+            env.seed(seed + 1000 * idx)
diff --git a/baselines/her/util.py b/baselines/her/util.py
new file mode 100644
index 0000000..d79a776
--- /dev/null
+++ b/baselines/her/util.py
@@ -0,0 +1,144 @@
+import os
+import subprocess
+import sys
+import importlib
+import inspect
+import functools
+
+import tensorflow as tf
+import numpy as np
+
+from baselines.common import tf_util as U
+
+
+def store_args(method):
+    """Stores provided method args as instance attributes.
+    """
+    argspec = inspect.getfullargspec(method)
+    defaults = {}
+    if argspec.defaults is not None:
+        defaults = dict(
+            zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
+    if argspec.kwonlydefaults is not None:
+        defaults.update(argspec.kwonlydefaults)
+    arg_names = argspec.args[1:]
+
+    @functools.wraps(method)
+    def wrapper(*positional_args, **keyword_args):
+        self = positional_args[0]
+        # Get default arg values
+        args = defaults.copy()
+        # Add provided arg values
+        for name, value in zip(arg_names, positional_args[1:]):
+            args[name] = value
+        args.update(keyword_args)
+        self.__dict__.update(args)
+        return method(*positional_args, **keyword_args)
+
+    return wrapper
+
+
+def import_function(spec):
+    """Import a function identified by a string like "pkg.module:fn_name".
+    """
+    mod_name, fn_name = spec.split(':')
+    module = importlib.import_module(mod_name)
+    fn = getattr(module, fn_name)
+    return fn
+
+
+def flatten_grads(var_list, grads):
+    """Flattens the gradients of the given variables into a single flat tensor.
+    """
+    return tf.concat([tf.reshape(grad, [U.numel(v)])
+                      for (v, grad) in zip(var_list, grads)], 0)
+
+
+def nn(input, layers_sizes, reuse=None, flatten=False, name=""):
+    """Creates a simple fully-connected neural network.
+    """
+    for i, size in enumerate(layers_sizes):
+        activation = tf.nn.relu if i < len(layers_sizes)-1 else None
+        input = tf.layers.dense(inputs=input,
+                                units=size,
+                                kernel_initializer=tf.contrib.layers.xavier_initializer(),
+                                reuse=reuse,
+                                name=name+'_'+str(i))
+        if activation:
+            input = activation(input)
+    if flatten:
+        assert layers_sizes[-1] == 1
+        input = tf.reshape(input, [-1])
+    return input
+
+
+def install_mpi_excepthook():
+    import sys
+    from mpi4py import MPI
+    old_hook = sys.excepthook
+
+    def new_hook(a, b, c):
+        old_hook(a, b, c)
+        sys.stdout.flush()
+        sys.stderr.flush()
+        MPI.COMM_WORLD.Abort()
+    sys.excepthook = new_hook
+
+
+def mpi_fork(n):
+    """Re-launches the current script with `n` MPI workers.
+    Returns "parent" for the original parent process, "child" for MPI children.
+    """
+    if n <= 1:
+        return "child"
+    if os.getenv("IN_MPI") is None:
+        env = os.environ.copy()
+        env.update(
+            MKL_NUM_THREADS="1",
+            OMP_NUM_THREADS="1",
+            IN_MPI="1"
+        )
+        # "-bind-to core" is crucial for good performance
+        args = [
+            "mpirun",
+            "-np",
+            str(n),
+            "-bind-to",
+            "core",
+            sys.executable
+        ]
+        args += sys.argv
+        subprocess.check_call(args, env=env)
+        return "parent"
+    else:
+        install_mpi_excepthook()
+        return "child"
+
+
+def convert_episode_to_batch_major(episode):
+    """Converts an episode to have the batch dimension in the major (first)
+    dimension.
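+
+    For example, an entry of shape (T, rollout_batch_size, dim) becomes
+    (rollout_batch_size, T, dim); the values themselves are unchanged.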
+ """ + episode_batch = {} + for key in episode.keys(): + val = np.array(episode[key]).copy() + # make inputs batch-major instead of time-major + episode_batch[key] = val.swapaxes(0, 1) + + return episode_batch + + +def transitions_in_episode_batch(episode_batch): + """Number of transitions in a given episode batch. + """ + shape = episode_batch['u'].shape + return shape[0] * shape[1] + + +def reshape_for_broadcasting(source, target): + """Reshapes a tensor (source) to have the correct shape and dtype of the target + before broadcasting it with MPI. + """ + dim = len(target.get_shape()) + shape = ([1] * (dim-1)) + [-1] + return tf.reshape(tf.cast(source, target.dtype), shape) diff --git a/baselines/logger.py b/baselines/logger.py index 79ab6c8..0d24e6f 100644 --- a/baselines/logger.py +++ b/baselines/logger.py @@ -6,6 +6,7 @@ import json import time import datetime import tempfile +from collections import defaultdict LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv'] # Also valid: json, tensorboard @@ -124,7 +125,7 @@ class CSVOutputFormat(KVWriter): if i > 0: self.file.write(',') v = kvs.get(k) - if v: + if v is not None: self.file.write(str(v)) self.file.write('\n') self.file.flush() @@ -168,24 +169,18 @@ class TensorBoardOutputFormat(KVWriter): self.writer.Close() self.writer = None -def make_output_format(format, ev_dir): - from mpi4py import MPI +def make_output_format(format, ev_dir, log_suffix=''): os.makedirs(ev_dir, exist_ok=True) - rank = MPI.COMM_WORLD.Get_rank() if format == 'stdout': return HumanOutputFormat(sys.stdout) elif format == 'log': - suffix = "" if rank==0 else ("-mpi%03i"%rank) - return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % suffix)) + return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix)) elif format == 'json': - assert rank==0 - return JSONOutputFormat(osp.join(ev_dir, 'progress.json')) + return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix)) elif format == 'csv': - assert rank==0 - return CSVOutputFormat(osp.join(ev_dir, 'progress.csv')) + return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix)) elif format == 'tensorboard': - assert rank==0 - return TensorBoardOutputFormat(osp.join(ev_dir, 'tb')) + return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix)) else: raise ValueError('Unknown format specified: %s' % (format,)) @@ -197,9 +192,16 @@ def logkv(key, val): """ Log a value of some diagnostic Call this once for each diagnostic quantity, each iteration + If called many times, last value will be used. """ Logger.CURRENT.logkv(key, val) +def logkv_mean(key, val): + """ + The same as logkv(), but if called many times, values averaged. 
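+
+    For example, logkv_mean('loss', 2.0) followed by logkv_mean('loss', 4.0)
+    reports loss as 3.0 at the next dumpkvs().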
+    """
+    Logger.CURRENT.logkv_mean(key, val)
+
 def logkvs(d):
     """
     Log a dictionary of key-value pairs
@@ -255,6 +257,33 @@ def get_dir():
 record_tabular = logkv
 dump_tabular = dumpkvs
 
+class ProfileKV:
+    """
+    Usage:
+    with logger.ProfileKV("interesting_scope"):
+        code
+    """
+    def __init__(self, n):
+        self.n = "wait_" + n
+    def __enter__(self):
+        self.t1 = time.time()
+    def __exit__(self, type, value, traceback):
+        Logger.CURRENT.name2val[self.n] += time.time() - self.t1
+
+def profile(n):
+    """
+    Usage:
+    @profile("my_func")
+    def my_func(): code
+    """
+    def decorator_with_name(func):
+        def func_wrapper(*args, **kwargs):
+            with ProfileKV(n):
+                return func(*args, **kwargs)
+        return func_wrapper
+    return decorator_with_name
+
+
 # ================================================================
 # Backend
 # ================================================================
@@ -265,7 +294,8 @@ class Logger(object):
     CURRENT = None # Current logger being used by the free functions above
 
     def __init__(self, dir, output_formats):
-        self.name2val = {} # values this iteration
+        self.name2val = defaultdict(float) # values this iteration
+        self.name2cnt = defaultdict(int)
         self.level = INFO
         self.dir = dir
         self.output_formats = output_formats
@@ -275,12 +305,21 @@ class Logger(object):
     def logkv(self, key, val):
         self.name2val[key] = val
 
+    def logkv_mean(self, key, val):
+        if val is None:
+            self.name2val[key] = None
+            return
+        oldval, cnt = self.name2val[key], self.name2cnt[key]
+        self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
+        self.name2cnt[key] = cnt + 1
+
     def dumpkvs(self):
         if self.level == DISABLED: return
         for fmt in self.output_formats:
             if isinstance(fmt, KVWriter):
                 fmt.writekvs(self.name2val)
         self.name2val.clear()
+        self.name2cnt.clear()
 
     def log(self, *args, level=INFO):
         if self.level <= level:
@@ -360,6 +399,11 @@ def _demo():
     logkv("a", 5.5)
     dumpkvs()
     info("^^^ should see a = 5.5")
+    logkv_mean("b", -22.5)
+    logkv_mean("b", -44.4)
+    logkv("a", 5.5)
+    dumpkvs()
+    info("^^^ should see b = -33.45 (the running mean of -22.5 and -44.4)")
 
     logkv("b", -2.5)
     dumpkvs()
diff --git a/setup.py b/setup.py
index 4bb492b..f976e37 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(name='baselines',
       packages=[package for package in find_packages()
                 if package.startswith('baselines')],
       install_requires=[
-          'gym[mujoco,atari,classic_control]',
+          'gym[mujoco,atari,classic_control,robotics]',
          'scipy',
          'tqdm',
          'joblib',
@@ -19,9 +19,11 @@ setup(name='baselines',
          'progressbar2',
          'mpi4py',
          'cloudpickle',
+          'tensorflow>=1.4.0',
+          'click',
      ],
      description='OpenAI baselines: high quality implementations of reinforcement learning algorithms',
      author='OpenAI',
      url='https://github.com/openai/baselines',
      author_email='gym@openai.com',
-      version='0.1.4')
+      version='0.1.5')
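
Illustrative usage of the logger additions above (logkv_mean, ProfileKV, and the profile
decorator); a minimal sketch for readers, not part of the patch itself:

    from baselines import logger

    @logger.profile("update")              # accumulates elapsed time under "wait_update"
    def update():
        pass

    for _ in range(10):
        with logger.ProfileKV("rollout"):  # accumulates elapsed time under "wait_rollout"
            update()
        logger.logkv_mean("reward", 1.0)   # averaged over calls until the next dumpkvs()
    logger.dumpkvs()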