* make baselines run without mpi wip * squash-merged latest master * further removing MPI references where unnecessary * more MPI removal * syntax and flake8 * MpiAdam becomes regular Adam if Mpi not present * autopep8 * add assertion to test in mpi_adam; fix trpo_mpi failure without MPI on cartpole * mpiless ddpg
403 lines
18 KiB
Python
Executable File
403 lines
18 KiB
Python
Executable File
from copy import copy
|
|
from functools import reduce
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import tensorflow.contrib as tc
|
|
|
|
from baselines import logger
|
|
from baselines.common.mpi_adam import MpiAdam
|
|
import baselines.common.tf_util as U
|
|
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
|
try:
|
|
from mpi4py import MPI
|
|
except ImportError:
|
|
MPI = None
|
|
|
|
def normalize(x, stats):
    """Standardize ``x`` with the running statistics in ``stats``.

    Args:
        x: value to normalize (anything supporting ``-`` and ``/``).
        stats: object exposing ``mean`` and ``std`` attributes, or ``None``
            to disable normalization.

    Returns:
        ``x`` unchanged when ``stats`` is ``None``, otherwise
        ``(x - stats.mean) / stats.std``.
    """
    return x if stats is None else (x - stats.mean) / stats.std
|
|
|
|
|
|
def denormalize(x, stats):
    """Undo :func:`normalize`: map a standardized value back to its raw scale.

    Args:
        x: normalized value (anything supporting ``*`` and ``+``).
        stats: object exposing ``mean`` and ``std`` attributes, or ``None``
            when no normalization is in use.

    Returns:
        ``x`` unchanged when ``stats`` is ``None``, otherwise
        ``x * stats.std + stats.mean``.
    """
    return x if stats is None else x * stats.std + stats.mean
|
|
|
|
def reduce_std(x, axis=None, keepdims=False):
    """Return the standard deviation of tensor ``x`` as a TensorFlow op.

    Computed as the square root of :func:`reduce_var`; ``axis`` and
    ``keepdims`` are forwarded to it unchanged.
    """
    variance = reduce_var(x, axis=axis, keepdims=keepdims)
    return tf.sqrt(variance)
|
|
|
|
def reduce_var(x, axis=None, keepdims=False):
    """Return the variance of tensor ``x`` as a TensorFlow op.

    The mean is taken with ``keepdims=True`` so it broadcasts against ``x``;
    the mean of the squared deviations is then reduced over ``axis`` using
    the caller's ``keepdims`` setting.
    """
    centered = x - tf.reduce_mean(x, axis=axis, keepdims=True)
    return tf.reduce_mean(tf.square(centered), axis=axis, keepdims=keepdims)
|
|
|
|
def get_target_updates(vars, target_vars, tau):
    """Build the ops that synchronize target variables with source variables.

    Args:
        vars: source variables.
        target_vars: target variables, paired positionally with ``vars``
            (both sequences must have the same length).
        tau: blend factor of the soft update,
            ``target <- (1 - tau) * target + tau * source``.

    Returns:
        A pair ``(init_op, soft_op)`` of grouped ops: the first copies every
        source variable into its target, the second applies the soft update.
    """
    logger.info('setting up target updates ...')
    assert len(vars) == len(target_vars)

    hard_copies, soft_blends = [], []
    for source, target in zip(vars, target_vars):
        logger.info(' {} <- {}'.format(target.name, source.name))
        hard_copies.append(tf.assign(target, source))
        blended = (1. - tau) * target + tau * source
        soft_blends.append(tf.assign(target, blended))

    assert len(hard_copies) == len(vars)
    assert len(soft_blends) == len(vars)
    return tf.group(*hard_copies), tf.group(*soft_blends)
|
|
|
|
|
|
def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev):
    """Build the op that refreshes ``perturbed_actor`` from ``actor``.

    Variables listed in ``actor.perturbable_vars`` are copied with zero-mean
    Gaussian noise of standard deviation ``param_noise_stddev`` added; all
    other variables are copied verbatim.

    Args:
        actor: source network exposing ``vars`` and ``perturbable_vars``.
        perturbed_actor: destination network with the same attributes; its
            ``vars`` are paired positionally with ``actor.vars``.
        param_noise_stddev: standard deviation of the added noise.

    Returns:
        A single grouped op performing every assignment.
    """
    assert len(actor.vars) == len(perturbed_actor.vars)
    assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)

    assignments = []
    for source, destination in zip(actor.vars, perturbed_actor.vars):
        if source not in actor.perturbable_vars:
            # Not perturbable: plain copy.
            logger.info(' {} <- {}'.format(destination.name, source.name))
            assignments.append(tf.assign(destination, source))
            continue
        logger.info(' {} <- {} + noise'.format(destination.name, source.name))
        noise = tf.random_normal(tf.shape(source), mean=0., stddev=param_noise_stddev)
        assignments.append(tf.assign(destination, source + noise))

    assert len(assignments) == len(actor.vars)
    return tf.group(*assignments)
|
|
|
|
|
|
class DDPG(object):
    """DDPG agent: builds the TensorFlow graph and exposes act / store / train operations.

    The constructor creates placeholders, the actor and critic networks, their
    target copies, optional parameter-noise copies, the gradient ops and the
    statistics ops. A session must be attached with :meth:`initialize` before
    any method that calls ``self.sess`` is used.
    """

    def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None,
                 gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
                 batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf),
                 adaptive_param_noise=True, adaptive_param_noise_policy_threshold=.1,
                 critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
        """Build the full computation graph for the agent.

        Args:
            actor: callable network ``actor(obs)``; must expose ``name``, ``vars``,
                ``trainable_vars`` and ``perturbable_vars`` and be shallow-copyable.
            critic: callable network ``critic(obs, action, reuse=...)``; must expose
                ``name``, ``vars``, ``trainable_vars`` and ``output_vars``.
            memory: replay buffer with ``append(obs0, action, reward, obs1, terminal1)``
                and ``sample(batch_size=...)`` returning a mapping with the keys
                ``'obs0'``, ``'obs1'``, ``'actions'``, ``'rewards'``, ``'terminals1'``.
            observation_shape: tuple, shape of one observation (no batch axis).
            action_shape: tuple, shape of one action (no batch axis).
            param_noise: optional parameter-noise spec exposing ``current_stddev``,
                ``adapt(distance)`` and ``get_stats()``.
            action_noise: optional callable returning an action-shaped noise sample,
                with a ``reset()`` method.
            gamma: discount factor used in the bootstrapped target.
            tau: soft target-update coefficient.
            normalize_returns: keep running return statistics and train the critic
                in normalized space.
            enable_popart: with ``normalize_returns``, rescale the critic output layer
                whenever the return statistics change.
            normalize_observations: keep running observation statistics and normalize
                network inputs with them.
            batch_size: number of transitions sampled per call to ``memory.sample``.
            observation_range: (low, high) clip applied to normalized observations.
            action_range: (low, high) clip applied to actions returned by :meth:`step`.
            return_range: (low, high) clip applied to normalized critic outputs/targets.
            adaptive_param_noise: accepted for interface compatibility; not read here.
            adaptive_param_noise_policy_threshold: accepted for interface
                compatibility; not read here.
            critic_l2_reg: L2 regularization coefficient for the critic (0 disables).
            actor_lr: step size passed to the actor optimizer.
            critic_lr: step size passed to the critic optimizer.
            clip_norm: gradient clipping norm forwarded to ``U.flatgrad``.
            reward_scale: factor applied to rewards before they are stored.
        """
        # Inputs.
        self.obs0 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs0')
        self.obs1 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs1')
        self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1')
        self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
        self.actions = tf.placeholder(tf.float32, shape=(None,) + action_shape, name='actions')
        self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target')
        self.param_noise_stddev = tf.placeholder(tf.float32, shape=(), name='param_noise_stddev')

        # Parameters.
        self.gamma = gamma
        self.tau = tau
        self.memory = memory
        self.normalize_observations = normalize_observations
        self.normalize_returns = normalize_returns
        self.action_noise = action_noise
        self.param_noise = param_noise
        self.action_range = action_range
        self.return_range = return_range
        self.observation_range = observation_range
        self.critic = critic
        self.actor = actor
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.clip_norm = clip_norm
        self.enable_popart = enable_popart
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.stats_sample = None
        self.critic_l2_reg = critic_l2_reg

        # Observation normalization.
        if self.normalize_observations:
            with tf.variable_scope('obs_rms'):
                self.obs_rms = RunningMeanStd(shape=observation_shape)
        else:
            self.obs_rms = None
        normalized_obs0 = tf.clip_by_value(normalize(self.obs0, self.obs_rms),
                                           self.observation_range[0], self.observation_range[1])
        normalized_obs1 = tf.clip_by_value(normalize(self.obs1, self.obs_rms),
                                           self.observation_range[0], self.observation_range[1])

        # Return normalization.
        if self.normalize_returns:
            with tf.variable_scope('ret_rms'):
                self.ret_rms = RunningMeanStd()
        else:
            self.ret_rms = None

        # Create target networks (shallow copies that differ only by name).
        target_actor = copy(actor)
        target_actor.name = 'target_actor'
        self.target_actor = target_actor
        target_critic = copy(critic)
        target_critic.name = 'target_critic'
        self.target_critic = target_critic

        # Create networks and core TF parts that are shared across setup parts.
        self.actor_tf = actor(normalized_obs0)
        self.normalized_critic_tf = critic(normalized_obs0, self.actions)
        self.critic_tf = denormalize(
            tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
        self.normalized_critic_with_actor_tf = critic(normalized_obs0, self.actor_tf, reuse=True)
        self.critic_with_actor_tf = denormalize(
            tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]),
            self.ret_rms)
        Q_obs1 = denormalize(target_critic(normalized_obs1, target_actor(normalized_obs1)), self.ret_rms)
        # One-step bootstrapped target; the bootstrap term is zeroed on terminal transitions.
        self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1

        # Set up parts.
        if self.param_noise is not None:
            self.setup_param_noise(normalized_obs0)
        self.setup_actor_optimizer()
        self.setup_critic_optimizer()
        if self.normalize_returns and self.enable_popart:
            self.setup_popart()
        self.setup_stats()
        self.setup_target_network_updates()

        self.initial_state = None  # recurrent architectures not supported yet

    @staticmethod
    def _count_params(shapes):
        """Return the total number of scalars described by a list of static shapes.

        An empty shape (scalar variable) counts as one parameter, hence the
        explicit initial value given to ``reduce``.
        """
        return sum(reduce(lambda x, y: x * y, shape, 1) for shape in shapes)

    def setup_target_network_updates(self):
        """Create the hard-copy and soft-update ops for the target actor and critic."""
        actor_init_updates, actor_soft_updates = get_target_updates(self.actor.vars, self.target_actor.vars, self.tau)
        critic_init_updates, critic_soft_updates = get_target_updates(self.critic.vars, self.target_critic.vars, self.tau)
        self.target_init_updates = [actor_init_updates, critic_init_updates]
        self.target_soft_updates = [actor_soft_updates, critic_soft_updates]

    def setup_param_noise(self, normalized_obs0):
        """Create the perturbed actor copies used for parameter-space exploration.

        Two copies of the actor are built on ``normalized_obs0``: one whose output
        is used to act, and a separate one used only to measure the distance
        between perturbed and unperturbed actions when adapting the noise scale.
        """
        assert self.param_noise is not None

        # Configure perturbed actor.
        param_noise_actor = copy(self.actor)
        param_noise_actor.name = 'param_noise_actor'
        self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
        logger.info('setting up param noise')
        self.perturb_policy_ops = get_perturbed_actor_updates(self.actor, param_noise_actor, self.param_noise_stddev)

        # Configure separate copy for stddev adaptation.
        adaptive_param_noise_actor = copy(self.actor)
        adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
        adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
        self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(
            self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
        # Root-mean-square difference between unperturbed and perturbed actions.
        self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))

    def setup_actor_optimizer(self):
        """Define the actor loss, its flattened gradient and the actor optimizer."""
        logger.info('setting up actor optimizer')
        # The actor is trained to maximize the critic's value of its own actions.
        self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)
        actor_shapes = [var.get_shape().as_list() for var in self.actor.trainable_vars]
        actor_nb_params = self._count_params(actor_shapes)
        logger.info(' actor shapes: {}'.format(actor_shapes))
        logger.info(' actor params: {}'.format(actor_nb_params))
        self.actor_grads = U.flatgrad(self.actor_loss, self.actor.trainable_vars, clip_norm=self.clip_norm)
        self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
                                       beta1=0.9, beta2=0.999, epsilon=1e-08)

    def setup_critic_optimizer(self):
        """Define the critic loss (plus optional L2 term), its gradient and optimizer."""
        logger.info('setting up critic optimizer')
        normalized_critic_target_tf = tf.clip_by_value(
            normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
        self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
        if self.critic_l2_reg > 0.:
            # Regularize only kernels, and skip the output layer.
            critic_reg_vars = [var for var in self.critic.trainable_vars
                               if 'kernel' in var.name and 'output' not in var.name]
            for var in critic_reg_vars:
                logger.info(' regularizing: {}'.format(var.name))
            logger.info(' applying l2 regularization with {}'.format(self.critic_l2_reg))
            critic_reg = tc.layers.apply_regularization(
                tc.layers.l2_regularizer(self.critic_l2_reg),
                weights_list=critic_reg_vars
            )
            self.critic_loss += critic_reg
        critic_shapes = [var.get_shape().as_list() for var in self.critic.trainable_vars]
        critic_nb_params = self._count_params(critic_shapes)
        logger.info(' critic shapes: {}'.format(critic_shapes))
        logger.info(' critic params: {}'.format(critic_nb_params))
        self.critic_grads = U.flatgrad(self.critic_loss, self.critic.trainable_vars, clip_norm=self.clip_norm)
        self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars,
                                        beta1=0.9, beta2=0.999, epsilon=1e-08)

    def setup_popart(self):
        """Create the ops that rescale the critic output layers after a return-stats update.

        Given the previous statistics (fed through ``old_mean`` / ``old_std``) and
        the current ``ret_rms`` statistics, the output kernel ``M`` and bias ``b``
        are reassigned so that ``new_std * (M'x + b') + new_mean`` equals
        ``old_std * (Mx + b) + old_mean``.
        """
        # See https://arxiv.org/pdf/1602.07714.pdf for details.
        self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
        new_std = self.ret_rms.std
        self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
        new_mean = self.ret_rms.mean

        self.renormalize_Q_outputs_op = []
        for vs in [self.critic.output_vars, self.target_critic.output_vars]:
            assert len(vs) == 2
            M, b = vs
            assert 'kernel' in M.name
            assert 'bias' in b.name
            assert M.get_shape()[-1] == 1
            assert b.get_shape()[-1] == 1
            self.renormalize_Q_outputs_op += [M.assign(M * self.old_std / new_std)]
            self.renormalize_Q_outputs_op += [b.assign((b * self.old_std + self.old_mean - new_mean) / new_std)]

    def setup_stats(self):
        """Collect the diagnostic ops, and their names, evaluated by :meth:`get_stats`."""
        ops = []
        names = []

        if self.normalize_returns:
            ops += [self.ret_rms.mean, self.ret_rms.std]
            names += ['ret_rms_mean', 'ret_rms_std']

        if self.normalize_observations:
            ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)]
            names += ['obs_rms_mean', 'obs_rms_std']

        ops += [tf.reduce_mean(self.critic_tf)]
        names += ['reference_Q_mean']
        ops += [reduce_std(self.critic_tf)]
        names += ['reference_Q_std']

        ops += [tf.reduce_mean(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_mean']
        ops += [reduce_std(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_std']

        ops += [tf.reduce_mean(self.actor_tf)]
        names += ['reference_action_mean']
        ops += [reduce_std(self.actor_tf)]
        names += ['reference_action_std']

        # Same test as everywhere else in the class: perturbed_actor_tf exists
        # exactly when param_noise is not None.
        if self.param_noise is not None:
            ops += [tf.reduce_mean(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_mean']
            ops += [reduce_std(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_std']

        self.stats_ops = ops
        self.stats_names = names

    def step(self, obs, apply_noise=True, compute_Q=True):
        """Compute an action (and optionally its Q-value) for ``obs``.

        Args:
            obs: observation(s); wrapped in a list and shape-adjusted to ``obs0``.
            apply_noise: use the perturbed actor and/or add action noise when
                the corresponding noise source is configured.
            compute_Q: also evaluate the critic on the unperturbed actor output.

        Returns:
            ``(action, q, None, None)`` where ``action`` is clipped to
            ``action_range`` and ``q`` is ``None`` when ``compute_Q`` is false.
        """
        if self.param_noise is not None and apply_noise:
            actor_tf = self.perturbed_actor_tf
        else:
            actor_tf = self.actor_tf
        feed_dict = {self.obs0: U.adjust_shape(self.obs0, [obs])}
        if compute_Q:
            action, q = self.sess.run([actor_tf, self.critic_with_actor_tf], feed_dict=feed_dict)
        else:
            action = self.sess.run(actor_tf, feed_dict=feed_dict)
            q = None

        if self.action_noise is not None and apply_noise:
            noise = self.action_noise()
            assert noise.shape == action[0].shape
            action += noise
        action = np.clip(action, self.action_range[0], self.action_range[1])

        return action, q, None, None

    def store_transition(self, obs0, action, reward, obs1, terminal1):
        """Append a batch of transitions to the replay memory.

        All arguments are indexed along their first axis; ``obs0.shape[0]``
        gives the number of transitions. Rewards are multiplied by
        ``reward_scale`` before being stored, and the observation statistics
        are updated per transition when observation normalization is on.
        """
        # Scale into a new value rather than in place: the caller's reward array
        # must not be modified, and an integer array cannot absorb a float
        # scale through ``*=``.
        scaled_reward = reward * self.reward_scale

        B = obs0.shape[0]
        for b in range(B):
            self.memory.append(obs0[b], action[b], scaled_reward[b], obs1[b], terminal1[b])
            if self.normalize_observations:
                self.obs_rms.update(np.array([obs0[b]]))

    def train(self):
        """Run one gradient step on a sampled batch.

        Returns:
            ``(critic_loss, actor_loss)`` evaluated on the batch before the update.
        """
        # Get a batch.
        batch = self.memory.sample(batch_size=self.batch_size)
        target_feed = {
            self.obs1: batch['obs1'],
            self.rewards: batch['rewards'],
            self.terminals1: batch['terminals1'].astype('float32'),
        }

        if self.normalize_returns and self.enable_popart:
            old_mean, old_std, target_Q = self.sess.run(
                [self.ret_rms.mean, self.ret_rms.std, self.target_Q], feed_dict=target_feed)
            self.ret_rms.update(target_Q.flatten())
            self.sess.run(self.renormalize_Q_outputs_op, feed_dict={
                self.old_std: np.array([old_std]),
                self.old_mean: np.array([old_mean]),
            })

            # Run sanity check. Disabled by default since it slows down things considerably.
            # print('running sanity check')
            # target_Q_new, new_mean, new_std = self.sess.run(
            #     [self.target_Q, self.ret_rms.mean, self.ret_rms.std], feed_dict=target_feed)
            # print(target_Q_new, target_Q, new_mean, new_std)
            # assert (np.abs(target_Q - target_Q_new) < 1e-3).all()
        else:
            target_Q = self.sess.run(self.target_Q, feed_dict=target_feed)

        # Get all gradients and perform a synced update.
        ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
        actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, feed_dict={
            self.obs0: batch['obs0'],
            self.actions: batch['actions'],
            self.critic_target: target_Q,
        })
        self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
        self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)

        return critic_loss, actor_loss

    def initialize(self, sess):
        """Attach ``sess``, initialize all variables, sync optimizers and copy weights to the targets."""
        self.sess = sess
        self.sess.run(tf.global_variables_initializer())
        self.actor_optimizer.sync()
        self.critic_optimizer.sync()
        self.sess.run(self.target_init_updates)

    def update_target_net(self):
        """Apply one soft update to the target actor and critic."""
        self.sess.run(self.target_soft_updates)

    def get_stats(self):
        """Evaluate the diagnostic ops on a fixed sample and return them as a dict.

        The sample is drawn from memory on the first call and then reused, so
        successive calls report values for the same inputs. Parameter-noise
        statistics are merged in when parameter noise is configured.
        """
        if self.stats_sample is None:
            # Get a sample and keep that fixed for all further computations.
            # This allows us to estimate the change in value for the same set of inputs.
            self.stats_sample = self.memory.sample(batch_size=self.batch_size)
        values = self.sess.run(self.stats_ops, feed_dict={
            self.obs0: self.stats_sample['obs0'],
            self.actions: self.stats_sample['actions'],
        })

        names = self.stats_names[:]
        assert len(names) == len(values)
        stats = dict(zip(names, values))

        if self.param_noise is not None:
            stats = {**stats, **self.param_noise.get_stats()}

        return stats

    def adapt_param_noise(self):
        """Adapt the parameter-noise scale from the measured policy distance.

        Returns:
            ``0.`` when parameter noise is disabled; otherwise the action-space
            distance between the unperturbed and perturbed policies on a sampled
            batch, averaged over all MPI ranks when ``mpi4py`` is available.
        """
        if self.param_noise is None:
            return 0.

        # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
        batch = self.memory.sample(batch_size=self.batch_size)
        self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
            self.param_noise_stddev: self.param_noise.current_stddev,
        })
        distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
            self.obs0: batch['obs0'],
            self.param_noise_stddev: self.param_noise.current_stddev,
        })

        # ``MPI`` is the module-level name, set to None when mpi4py is missing.
        # The collective is issued exactly once per call.
        if MPI is not None:
            mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
        else:
            mean_distance = distance

        self.param_noise.adapt(mean_distance)
        return mean_distance

    def reset(self):
        """Reset per-episode state: action noise and the perturbed actor weights."""
        # Reset internal state after an episode is complete.
        if self.action_noise is not None:
            self.action_noise.reset()
        if self.param_noise is not None:
            self.sess.run(self.perturb_policy_ops, feed_dict={
                self.param_noise_stddev: self.param_noise.current_stddev,
            })
|