Refactor DDPG (#111)
* run ddpg on Mujoco benchmark RUN BENCHMARKS
* autopep8
* fixed all syntax in refactored ddpg
* a little bit more refactoring
* autopep8
* identity test with ddpg WIP
* enable test_identity with ddpg
* refactored ddpg RUN BENCHMARKS
* autopep8
* include ddpg into style check
* fixing tests RUN BENCHMARKS
* set default seed to None RUN BENCHMARKS
* run tests and benchmarks in separate buildkite steps RUN BENCHMARKS
* cleanup pdb usage
* flake8 and cleanups
* re-enabled all benchmarks in run-benchmarks-new.py
* flake8 complaints
* deepq model builder compatible with network functions returning single tensor
* remove ddpg test with test_discrete_identity
* make ppo_metal use make_vec_env instead of make_atari_env
* make ppo_metal use make_vec_env instead of make_atari_env
* fixed syntax in ppo_metal.run_atari
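Taken together, the changes below move DDPG onto the same interface as the other algorithms: networks come from the shared builders, training happens through a single learn() function, and the old argparse runner that called baselines.ddpg.training is removed. A minimal usage sketch, with the import path and environment chosen for illustration only (they are not spelled out in this diff):

import gym
from baselines.ddpg.ddpg import learn    # assumed module path for the refactored learn()

env = gym.make('Pendulum-v0')            # any Gym env with a symmetric continuous action space
agent = learn(network='mlp',             # registered network builder name
              env=env,
              total_timesteps=20000,     # nb_epochs is derived from this when left as None
              noise_type='adaptive-param_0.2',
              layer_norm=True)           # forwarded to the mlp builder via **network_kwargs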
@@ -28,7 +28,7 @@ def nature_cnn(unscaled_images, **conv_kwargs):
 
 
 @register("mlp")
-def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
+def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
     """
     Stack of fully-connected layers to be used in a policy / q-function approximator
 
@@ -49,8 +49,12 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
     def network_fn(X):
         h = tf.layers.flatten(X)
         for i in range(num_layers):
-            h = activation(fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2)))
-        return h, None
+            h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
+            if layer_norm:
+                h = tf.contrib.layers.layer_norm(h, center=True, scale=True)
+            h = activation(h)
+
+        return h
 
     return network_fn
 
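The loop body above is split because layer normalization has to be applied to the pre-activation output: the fully connected layer, the optional tf.contrib.layers.layer_norm, and the activation are now three separate steps, and the builder returns a single tensor instead of an (h, None) pair. A self-contained sketch of the same pattern, with tf.layers.dense standing in for baselines' fc helper (that substitution is an assumption made for illustration):

import numpy as np
import tensorflow as tf

def mlp_sketch(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
    def network_fn(X):
        h = tf.layers.flatten(X)
        for i in range(num_layers):
            # linear layer first ...
            h = tf.layers.dense(h, num_hidden, name='mlp_fc{}'.format(i),
                                kernel_initializer=tf.orthogonal_initializer(np.sqrt(2)))
            # ... then optional normalization of the pre-activation output ...
            if layer_norm:
                h = tf.contrib.layers.layer_norm(h, center=True, scale=True)
            # ... then the nonlinearity
            h = activation(h)
        return h              # single tensor; no (h, None) tuple anymore
    return network_fn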
@@ -58,7 +62,7 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
 @register("cnn")
 def cnn(**conv_kwargs):
     def network_fn(X):
-        return nature_cnn(X, **conv_kwargs), None
+        return nature_cnn(X, **conv_kwargs)
     return network_fn
 
 
@@ -72,7 +76,7 @@ def cnn_small(**conv_kwargs):
         h = activ(conv(h, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
         h = conv_to_fc(h)
         h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2)))
-        return h, None
+        return h
     return network_fn
 
 
@@ -190,7 +194,7 @@ def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
                                        activation_fn=tf.nn.relu,
                                        **conv_kwargs)
 
-        return out, None
+        return out
     return network_fn
 
 def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
@@ -212,7 +216,9 @@ def get_network_builder(name):
         return network_fn
 
     """
-    if name in mapping:
+    if callable(name):
+        return name
+    elif name in mapping:
         return mapping[name]
     else:
         raise ValueError('Unknown network type: {}'.format(name))
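With this change get_network_builder also accepts a user-defined callable and returns it unchanged, so callers such as the refactored DDPG can pass a custom network function instead of a registry name. A small sketch (the module path baselines.common.models is assumed, not shown in this hunk):

import tensorflow as tf
from baselines.common.models import get_network_builder

def my_tiny_net(X):
    # custom network function returning a single latent tensor
    return tf.tanh(tf.layers.dense(tf.layers.flatten(X), 32))

builder = get_network_builder(my_tiny_net)   # callable -> handed back as-is
assert builder is my_tiny_net
mlp_builder = get_network_builder('mlp')     # string -> looked up in the registry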
@@ -139,7 +139,9 @@ def build_policy(env, policy_network, value_network=None, normalize_observation
         encoded_x = encode_observation(ob_space, encoded_x)
 
         with tf.variable_scope('pi', reuse=tf.AUTO_REUSE):
-            policy_latent, recurrent_tensors = policy_network(encoded_x)
+            policy_latent = policy_network(encoded_x)
+            if isinstance(policy_latent, tuple):
+                policy_latent, recurrent_tensors = policy_latent
 
             if recurrent_tensors is not None:
                 # recurrent architecture, need a few more steps
@@ -160,7 +162,8 @@ def build_policy(env, policy_network, value_network=None, normalize_observation
             assert callable(_v_net)
 
             with tf.variable_scope('vf', reuse=tf.AUTO_REUSE):
-                vf_latent, _ = _v_net(encoded_x)
+                # TODO recurrent architectures are not supported with value_network=copy yet
+                vf_latent = _v_net(encoded_x)
 
         policy = PolicyWithValue(
             env=env,
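Both build_policy hunks rely on the same convention: a network function may now return either a bare latent tensor or a (latent, recurrent_tensors) tuple. Written out as a stand-alone helper purely for illustration (no such helper is added by this commit):

def unpack_network_output(output):
    # Normalize a network function's return value to (latent, recurrent_tensors).
    recurrent_tensors = None
    if isinstance(output, tuple):
        output, recurrent_tensors = output
    return output, recurrent_tensors

# unpack_network_output(h) -> (h, None) for the simplified builders above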
@@ -14,13 +14,17 @@ learn_kwargs = {
     'a2c' : {},
     'acktr': {},
     'deepq': {},
+    'ddpg': dict(nb_epochs=None, layer_norm=True),
     'ppo2': dict(lr=1e-3, nsteps=64, ent_coef=0.0),
     'trpo_mpi': dict(timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.01)
 }
 
 
+algos_disc = ['a2c', 'deepq', 'ppo2', 'trpo_mpi']
+algos_cont = ['a2c', 'ddpg', 'ppo2', 'trpo_mpi']
+
 @pytest.mark.slow
-@pytest.mark.parametrize("alg", learn_kwargs.keys())
+@pytest.mark.parametrize("alg", algos_disc)
 def test_discrete_identity(alg):
     '''
     Test if the algorithm (with an mlp policy)
@@ -35,7 +39,7 @@ def test_discrete_identity(alg):
     simple_test(env_fn, learn_fn, 0.9)
 
 @pytest.mark.slow
-@pytest.mark.parametrize("alg", ['a2c', 'ppo2', 'trpo_mpi'])
+@pytest.mark.parametrize("alg", algos_cont)
 def test_continuous_identity(alg):
     '''
     Test if the algorithm (with an mlp policy)
@@ -51,5 +55,5 @@ def test_continuous_identity(alg):
     simple_test(env_fn, learn_fn, -0.1)
 
 if __name__ == '__main__':
-    test_continuous_identity('a2c')
+    test_continuous_identity('ddpg')
 
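For alg='ddpg' the parametrized continuous-identity test forwards learn_kwargs['ddpg'] to the DDPG learn() that appears further down in this diff, which is how layer_norm=True reaches the mlp builder. A rough sketch of that wiring; get_learn_function is assumed to come from baselines.run, and the exact keyword plumbing inside simple_test may differ:

from functools import partial
from baselines.run import get_learn_function      # assumed helper used by these tests

kwargs = dict(nb_epochs=None, layer_norm=True)    # learn_kwargs['ddpg'] above
learn_fn = partial(get_learn_function('ddpg'), network='mlp', **kwargs)
# simple_test(env_fn, learn_fn, -0.1) then calls learn_fn on the identity env and checks the score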
@@ -1,378 +1,240 @@
|
|||||||
from copy import copy
|
import os
|
||||||
from functools import reduce
|
import time
|
||||||
|
from collections import deque
|
||||||
|
import pickle
|
||||||
|
|
||||||
import numpy as np
|
from baselines.ddpg.ddpg_learner import DDPG
|
||||||
import tensorflow as tf
|
from baselines.ddpg.models import Actor, Critic
|
||||||
import tensorflow.contrib as tc
|
from baselines.ddpg.memory import Memory
|
||||||
|
from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
|
||||||
|
|
||||||
|
import baselines.common.tf_util as U
|
||||||
|
|
||||||
from baselines import logger
|
from baselines import logger
|
||||||
from baselines.common.mpi_adam import MpiAdam
|
import numpy as np
|
||||||
import baselines.common.tf_util as U
|
|
||||||
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
|
|
||||||
def normalize(x, stats):
|
|
||||||
if stats is None:
|
def learn(network, env,
|
||||||
|
seed=None,
|
||||||
|
total_timesteps=None,
|
||||||
|
nb_epochs=None, # with default settings, perform 1M steps total
|
||||||
|
nb_epoch_cycles=20,
|
||||||
|
nb_rollout_steps=100,
|
||||||
|
reward_scale=1.0,
|
||||||
|
render=False,
|
||||||
|
render_eval=False,
|
||||||
|
noise_type='adaptive-param_0.2',
|
||||||
|
normalize_returns=False,
|
||||||
|
normalize_observations=True,
|
||||||
|
critic_l2_reg=1e-2,
|
||||||
|
actor_lr=1e-4,
|
||||||
|
critic_lr=1e-3,
|
||||||
|
popart=False,
|
||||||
|
gamma=0.99,
|
||||||
|
clip_norm=None,
|
||||||
|
nb_train_steps=50, # per epoch cycle and MPI worker,
|
||||||
|
nb_eval_steps=100,
|
||||||
|
batch_size=64, # per MPI worker
|
||||||
|
tau=0.01,
|
||||||
|
eval_env=None,
|
||||||
|
param_noise_adaption_interval=50,
|
||||||
|
**network_kwargs):
|
||||||
|
|
||||||
|
|
||||||
|
if total_timesteps is not None:
|
||||||
|
assert nb_epochs is None
|
||||||
|
nb_epochs = int(total_timesteps) // (nb_epoch_cycles * nb_rollout_steps)
|
||||||
|
else:
|
||||||
|
nb_epochs = 500
|
||||||
|
|
||||||
|
rank = MPI.COMM_WORLD.Get_rank()
|
||||||
|
nb_actions = env.action_space.shape[-1]
|
||||||
|
assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.
|
||||||
|
|
||||||
|
memory = Memory(limit=int(1e6), action_shape=env.action_space.shape, observation_shape=env.observation_space.shape)
|
||||||
|
critic = Critic(network=network, **network_kwargs)
|
||||||
|
actor = Actor(nb_actions, network=network, **network_kwargs)
|
||||||
|
|
||||||
|
action_noise = None
|
||||||
|
param_noise = None
|
||||||
|
nb_actions = env.action_space.shape[-1]
|
||||||
|
if noise_type is not None:
|
||||||
|
for current_noise_type in noise_type.split(','):
|
||||||
|
current_noise_type = current_noise_type.strip()
|
||||||
|
if current_noise_type == 'none':
|
||||||
|
pass
|
||||||
|
elif 'adaptive-param' in current_noise_type:
|
||||||
|
_, stddev = current_noise_type.split('_')
|
||||||
|
param_noise = AdaptiveParamNoiseSpec(initial_stddev=float(stddev), desired_action_stddev=float(stddev))
|
||||||
|
elif 'normal' in current_noise_type:
|
||||||
|
_, stddev = current_noise_type.split('_')
|
||||||
|
action_noise = NormalActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions))
|
||||||
|
elif 'ou' in current_noise_type:
|
||||||
|
_, stddev = current_noise_type.split('_')
|
||||||
|
action_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions))
|
||||||
|
else:
|
||||||
|
raise RuntimeError('unknown noise type "{}"'.format(current_noise_type))
|
||||||
|
|
||||||
|
max_action = env.action_space.high
|
||||||
|
logger.info('scaling actions by {} before executing in env'.format(max_action))
|
||||||
|
agent = DDPG(actor, critic, memory, env.observation_space.shape, env.action_space.shape,
|
||||||
|
gamma=gamma, tau=tau, normalize_returns=normalize_returns, normalize_observations=normalize_observations,
|
||||||
|
batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
|
||||||
|
actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm,
|
||||||
|
reward_scale=reward_scale)
|
||||||
|
logger.info('Using agent with the following configuration:')
|
||||||
|
logger.info(str(agent.__dict__.items()))
|
||||||
|
|
||||||
|
eval_episode_rewards_history = deque(maxlen=100)
|
||||||
|
episode_rewards_history = deque(maxlen=100)
|
||||||
|
sess = U.get_session()
|
||||||
|
# Prepare everything.
|
||||||
|
agent.initialize(sess)
|
||||||
|
sess.graph.finalize()
|
||||||
|
|
||||||
|
agent.reset()
|
||||||
|
obs = env.reset()
|
||||||
|
if eval_env is not None:
|
||||||
|
eval_obs = eval_env.reset()
|
||||||
|
done = False
|
||||||
|
episode_reward = 0.
|
||||||
|
episode_step = 0
|
||||||
|
episodes = 0
|
||||||
|
t = 0
|
||||||
|
|
||||||
|
epoch = 0
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
epoch_episode_rewards = []
|
||||||
|
epoch_episode_steps = []
|
||||||
|
epoch_actions = []
|
||||||
|
epoch_qs = []
|
||||||
|
epoch_episodes = 0
|
||||||
|
for epoch in range(nb_epochs):
|
||||||
|
for cycle in range(nb_epoch_cycles):
|
||||||
|
# Perform rollouts.
|
||||||
|
for t_rollout in range(nb_rollout_steps):
|
||||||
|
# Predict next action.
|
||||||
|
action, q, _, _ = agent.step(obs, apply_noise=True, compute_Q=True)
|
||||||
|
assert action.shape == env.action_space.shape
|
||||||
|
|
||||||
|
# Execute next action.
|
||||||
|
if rank == 0 and render:
|
||||||
|
env.render()
|
||||||
|
assert max_action.shape == action.shape
|
||||||
|
new_obs, r, done, info = env.step(max_action * action) # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
|
||||||
|
t += 1
|
||||||
|
if rank == 0 and render:
|
||||||
|
env.render()
|
||||||
|
episode_reward += r
|
||||||
|
episode_step += 1
|
||||||
|
|
||||||
|
# Book-keeping.
|
||||||
|
epoch_actions.append(action)
|
||||||
|
epoch_qs.append(q)
|
||||||
|
agent.store_transition(obs, action, r, new_obs, done)
|
||||||
|
obs = new_obs
|
||||||
|
|
||||||
|
if done:
|
||||||
|
# Episode done.
|
||||||
|
epoch_episode_rewards.append(episode_reward)
|
||||||
|
episode_rewards_history.append(episode_reward)
|
||||||
|
epoch_episode_steps.append(episode_step)
|
||||||
|
episode_reward = 0.
|
||||||
|
episode_step = 0
|
||||||
|
epoch_episodes += 1
|
||||||
|
episodes += 1
|
||||||
|
|
||||||
|
agent.reset()
|
||||||
|
obs = env.reset()
|
||||||
|
|
||||||
|
# Train.
|
||||||
|
epoch_actor_losses = []
|
||||||
|
epoch_critic_losses = []
|
||||||
|
epoch_adaptive_distances = []
|
||||||
|
for t_train in range(nb_train_steps):
|
||||||
|
# Adapt param noise, if necessary.
|
||||||
|
if memory.nb_entries >= batch_size and t_train % param_noise_adaption_interval == 0:
|
||||||
|
distance = agent.adapt_param_noise()
|
||||||
|
epoch_adaptive_distances.append(distance)
|
||||||
|
|
||||||
|
cl, al = agent.train()
|
||||||
|
epoch_critic_losses.append(cl)
|
||||||
|
epoch_actor_losses.append(al)
|
||||||
|
agent.update_target_net()
|
||||||
|
|
||||||
|
# Evaluate.
|
||||||
|
eval_episode_rewards = []
|
||||||
|
eval_qs = []
|
||||||
|
if eval_env is not None:
|
||||||
|
eval_episode_reward = 0.
|
||||||
|
for t_rollout in range(nb_eval_steps):
|
||||||
|
eval_action, eval_q, _, _ = agent.step(eval_obs, apply_noise=False, compute_Q=True)
|
||||||
|
eval_obs, eval_r, eval_done, eval_info = eval_env.step(max_action * eval_action) # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
|
||||||
|
if render_eval:
|
||||||
|
eval_env.render()
|
||||||
|
eval_episode_reward += eval_r
|
||||||
|
|
||||||
|
eval_qs.append(eval_q)
|
||||||
|
if eval_done:
|
||||||
|
eval_obs = eval_env.reset()
|
||||||
|
eval_episode_rewards.append(eval_episode_reward)
|
||||||
|
eval_episode_rewards_history.append(eval_episode_reward)
|
||||||
|
eval_episode_reward = 0.
|
||||||
|
|
||||||
|
mpi_size = MPI.COMM_WORLD.Get_size()
|
||||||
|
# Log stats.
|
||||||
|
# XXX shouldn't call np.mean on variable length lists
|
||||||
|
duration = time.time() - start_time
|
||||||
|
stats = agent.get_stats()
|
||||||
|
combined_stats = stats.copy()
|
||||||
|
combined_stats['rollout/return'] = np.mean(epoch_episode_rewards)
|
||||||
|
combined_stats['rollout/return_history'] = np.mean(episode_rewards_history)
|
||||||
|
combined_stats['rollout/episode_steps'] = np.mean(epoch_episode_steps)
|
||||||
|
combined_stats['rollout/actions_mean'] = np.mean(epoch_actions)
|
||||||
|
combined_stats['rollout/Q_mean'] = np.mean(epoch_qs)
|
||||||
|
combined_stats['train/loss_actor'] = np.mean(epoch_actor_losses)
|
||||||
|
combined_stats['train/loss_critic'] = np.mean(epoch_critic_losses)
|
||||||
|
combined_stats['train/param_noise_distance'] = np.mean(epoch_adaptive_distances)
|
||||||
|
combined_stats['total/duration'] = duration
|
||||||
|
combined_stats['total/steps_per_second'] = float(t) / float(duration)
|
||||||
|
combined_stats['total/episodes'] = episodes
|
||||||
|
combined_stats['rollout/episodes'] = epoch_episodes
|
||||||
|
combined_stats['rollout/actions_std'] = np.std(epoch_actions)
|
||||||
|
# Evaluation statistics.
|
||||||
|
if eval_env is not None:
|
||||||
|
combined_stats['eval/return'] = eval_episode_rewards
|
||||||
|
combined_stats['eval/return_history'] = np.mean(eval_episode_rewards_history)
|
||||||
|
combined_stats['eval/Q'] = eval_qs
|
||||||
|
combined_stats['eval/episodes'] = len(eval_episode_rewards)
|
||||||
|
def as_scalar(x):
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
assert x.size == 1
|
||||||
|
return x[0]
|
||||||
|
elif np.isscalar(x):
|
||||||
return x
|
return x
|
||||||
return (x - stats.mean) / stats.std
|
|
||||||
|
|
||||||
|
|
||||||
def denormalize(x, stats):
|
|
||||||
if stats is None:
|
|
||||||
return x
|
|
||||||
return x * stats.std + stats.mean
|
|
||||||
|
|
||||||
def reduce_std(x, axis=None, keepdims=False):
|
|
||||||
return tf.sqrt(reduce_var(x, axis=axis, keepdims=keepdims))
|
|
||||||
|
|
||||||
def reduce_var(x, axis=None, keepdims=False):
|
|
||||||
m = tf.reduce_mean(x, axis=axis, keepdims=True)
|
|
||||||
devs_squared = tf.square(x - m)
|
|
||||||
return tf.reduce_mean(devs_squared, axis=axis, keepdims=keepdims)
|
|
||||||
|
|
||||||
def get_target_updates(vars, target_vars, tau):
|
|
||||||
logger.info('setting up target updates ...')
|
|
||||||
soft_updates = []
|
|
||||||
init_updates = []
|
|
||||||
assert len(vars) == len(target_vars)
|
|
||||||
for var, target_var in zip(vars, target_vars):
|
|
||||||
logger.info(' {} <- {}'.format(target_var.name, var.name))
|
|
||||||
init_updates.append(tf.assign(target_var, var))
|
|
||||||
soft_updates.append(tf.assign(target_var, (1. - tau) * target_var + tau * var))
|
|
||||||
assert len(init_updates) == len(vars)
|
|
||||||
assert len(soft_updates) == len(vars)
|
|
||||||
return tf.group(*init_updates), tf.group(*soft_updates)
|
|
||||||
|
|
||||||
|
|
||||||
def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev):
|
|
||||||
assert len(actor.vars) == len(perturbed_actor.vars)
|
|
||||||
assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)
|
|
||||||
|
|
||||||
updates = []
|
|
||||||
for var, perturbed_var in zip(actor.vars, perturbed_actor.vars):
|
|
||||||
if var in actor.perturbable_vars:
|
|
||||||
logger.info(' {} <- {} + noise'.format(perturbed_var.name, var.name))
|
|
||||||
updates.append(tf.assign(perturbed_var, var + tf.random_normal(tf.shape(var), mean=0., stddev=param_noise_stddev)))
|
|
||||||
else:
|
else:
|
||||||
logger.info(' {} <- {}'.format(perturbed_var.name, var.name))
|
raise ValueError('expected scalar, got %s'%x)
|
||||||
updates.append(tf.assign(perturbed_var, var))
|
combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([as_scalar(x) for x in combined_stats.values()]))
|
||||||
assert len(updates) == len(actor.vars)
|
combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}
|
||||||
return tf.group(*updates)
|
|
||||||
|
# Total statistics.
|
||||||
|
combined_stats['total/epochs'] = epoch + 1
|
||||||
|
combined_stats['total/steps'] = t
|
||||||
|
|
||||||
|
for key in sorted(combined_stats.keys()):
|
||||||
|
logger.record_tabular(key, combined_stats[key])
|
||||||
|
logger.dump_tabular()
|
||||||
|
logger.info('')
|
||||||
|
logdir = logger.get_dir()
|
||||||
|
if rank == 0 and logdir:
|
||||||
|
if hasattr(env, 'get_state'):
|
||||||
|
with open(os.path.join(logdir, 'env_state.pkl'), 'wb') as f:
|
||||||
|
pickle.dump(env.get_state(), f)
|
||||||
|
if eval_env and hasattr(eval_env, 'get_state'):
|
||||||
|
with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f:
|
||||||
|
pickle.dump(eval_env.get_state(), f)
|
||||||
|
|
||||||
|
|
||||||
class DDPG(object):
|
return agent
|
||||||
def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None,
|
|
||||||
gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
|
|
||||||
batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf),
|
|
||||||
adaptive_param_noise=True, adaptive_param_noise_policy_threshold=.1,
|
|
||||||
critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
|
|
||||||
# Inputs.
|
|
||||||
self.obs0 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs0')
|
|
||||||
self.obs1 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs1')
|
|
||||||
self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1')
|
|
||||||
self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
|
|
||||||
self.actions = tf.placeholder(tf.float32, shape=(None,) + action_shape, name='actions')
|
|
||||||
self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target')
|
|
||||||
self.param_noise_stddev = tf.placeholder(tf.float32, shape=(), name='param_noise_stddev')
|
|
||||||
|
|
||||||
# Parameters.
|
|
||||||
self.gamma = gamma
|
|
||||||
self.tau = tau
|
|
||||||
self.memory = memory
|
|
||||||
self.normalize_observations = normalize_observations
|
|
||||||
self.normalize_returns = normalize_returns
|
|
||||||
self.action_noise = action_noise
|
|
||||||
self.param_noise = param_noise
|
|
||||||
self.action_range = action_range
|
|
||||||
self.return_range = return_range
|
|
||||||
self.observation_range = observation_range
|
|
||||||
self.critic = critic
|
|
||||||
self.actor = actor
|
|
||||||
self.actor_lr = actor_lr
|
|
||||||
self.critic_lr = critic_lr
|
|
||||||
self.clip_norm = clip_norm
|
|
||||||
self.enable_popart = enable_popart
|
|
||||||
self.reward_scale = reward_scale
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.stats_sample = None
|
|
||||||
self.critic_l2_reg = critic_l2_reg
|
|
||||||
|
|
||||||
# Observation normalization.
|
|
||||||
if self.normalize_observations:
|
|
||||||
with tf.variable_scope('obs_rms'):
|
|
||||||
self.obs_rms = RunningMeanStd(shape=observation_shape)
|
|
||||||
else:
|
|
||||||
self.obs_rms = None
|
|
||||||
normalized_obs0 = tf.clip_by_value(normalize(self.obs0, self.obs_rms),
|
|
||||||
self.observation_range[0], self.observation_range[1])
|
|
||||||
normalized_obs1 = tf.clip_by_value(normalize(self.obs1, self.obs_rms),
|
|
||||||
self.observation_range[0], self.observation_range[1])
|
|
||||||
|
|
||||||
# Return normalization.
|
|
||||||
if self.normalize_returns:
|
|
||||||
with tf.variable_scope('ret_rms'):
|
|
||||||
self.ret_rms = RunningMeanStd()
|
|
||||||
else:
|
|
||||||
self.ret_rms = None
|
|
||||||
|
|
||||||
# Create target networks.
|
|
||||||
target_actor = copy(actor)
|
|
||||||
target_actor.name = 'target_actor'
|
|
||||||
self.target_actor = target_actor
|
|
||||||
target_critic = copy(critic)
|
|
||||||
target_critic.name = 'target_critic'
|
|
||||||
self.target_critic = target_critic
|
|
||||||
|
|
||||||
# Create networks and core TF parts that are shared across setup parts.
|
|
||||||
self.actor_tf = actor(normalized_obs0)
|
|
||||||
self.normalized_critic_tf = critic(normalized_obs0, self.actions)
|
|
||||||
self.critic_tf = denormalize(tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
|
|
||||||
self.normalized_critic_with_actor_tf = critic(normalized_obs0, self.actor_tf, reuse=True)
|
|
||||||
self.critic_with_actor_tf = denormalize(tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
|
|
||||||
Q_obs1 = denormalize(target_critic(normalized_obs1, target_actor(normalized_obs1)), self.ret_rms)
|
|
||||||
self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1
|
|
||||||
|
|
||||||
# Set up parts.
|
|
||||||
if self.param_noise is not None:
|
|
||||||
self.setup_param_noise(normalized_obs0)
|
|
||||||
self.setup_actor_optimizer()
|
|
||||||
self.setup_critic_optimizer()
|
|
||||||
if self.normalize_returns and self.enable_popart:
|
|
||||||
self.setup_popart()
|
|
||||||
self.setup_stats()
|
|
||||||
self.setup_target_network_updates()
|
|
||||||
|
|
||||||
def setup_target_network_updates(self):
|
|
||||||
actor_init_updates, actor_soft_updates = get_target_updates(self.actor.vars, self.target_actor.vars, self.tau)
|
|
||||||
critic_init_updates, critic_soft_updates = get_target_updates(self.critic.vars, self.target_critic.vars, self.tau)
|
|
||||||
self.target_init_updates = [actor_init_updates, critic_init_updates]
|
|
||||||
self.target_soft_updates = [actor_soft_updates, critic_soft_updates]
|
|
||||||
|
|
||||||
def setup_param_noise(self, normalized_obs0):
|
|
||||||
assert self.param_noise is not None
|
|
||||||
|
|
||||||
# Configure perturbed actor.
|
|
||||||
param_noise_actor = copy(self.actor)
|
|
||||||
param_noise_actor.name = 'param_noise_actor'
|
|
||||||
self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
|
|
||||||
logger.info('setting up param noise')
|
|
||||||
self.perturb_policy_ops = get_perturbed_actor_updates(self.actor, param_noise_actor, self.param_noise_stddev)
|
|
||||||
|
|
||||||
# Configure separate copy for stddev adoption.
|
|
||||||
adaptive_param_noise_actor = copy(self.actor)
|
|
||||||
adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
|
|
||||||
adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
|
|
||||||
self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
|
|
||||||
self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))
|
|
||||||
|
|
||||||
def setup_actor_optimizer(self):
|
|
||||||
logger.info('setting up actor optimizer')
|
|
||||||
self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)
|
|
||||||
actor_shapes = [var.get_shape().as_list() for var in self.actor.trainable_vars]
|
|
||||||
actor_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
|
|
||||||
logger.info(' actor shapes: {}'.format(actor_shapes))
|
|
||||||
logger.info(' actor params: {}'.format(actor_nb_params))
|
|
||||||
self.actor_grads = U.flatgrad(self.actor_loss, self.actor.trainable_vars, clip_norm=self.clip_norm)
|
|
||||||
self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
|
|
||||||
beta1=0.9, beta2=0.999, epsilon=1e-08)
|
|
||||||
|
|
||||||
def setup_critic_optimizer(self):
|
|
||||||
logger.info('setting up critic optimizer')
|
|
||||||
normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
|
|
||||||
self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
|
|
||||||
if self.critic_l2_reg > 0.:
|
|
||||||
critic_reg_vars = [var for var in self.critic.trainable_vars if 'kernel' in var.name and 'output' not in var.name]
|
|
||||||
for var in critic_reg_vars:
|
|
||||||
logger.info(' regularizing: {}'.format(var.name))
|
|
||||||
logger.info(' applying l2 regularization with {}'.format(self.critic_l2_reg))
|
|
||||||
critic_reg = tc.layers.apply_regularization(
|
|
||||||
tc.layers.l2_regularizer(self.critic_l2_reg),
|
|
||||||
weights_list=critic_reg_vars
|
|
||||||
)
|
|
||||||
self.critic_loss += critic_reg
|
|
||||||
critic_shapes = [var.get_shape().as_list() for var in self.critic.trainable_vars]
|
|
||||||
critic_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
|
|
||||||
logger.info(' critic shapes: {}'.format(critic_shapes))
|
|
||||||
logger.info(' critic params: {}'.format(critic_nb_params))
|
|
||||||
self.critic_grads = U.flatgrad(self.critic_loss, self.critic.trainable_vars, clip_norm=self.clip_norm)
|
|
||||||
self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars,
|
|
||||||
beta1=0.9, beta2=0.999, epsilon=1e-08)
|
|
||||||
|
|
||||||
def setup_popart(self):
|
|
||||||
# See https://arxiv.org/pdf/1602.07714.pdf for details.
|
|
||||||
self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
|
|
||||||
new_std = self.ret_rms.std
|
|
||||||
self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
|
|
||||||
new_mean = self.ret_rms.mean
|
|
||||||
|
|
||||||
self.renormalize_Q_outputs_op = []
|
|
||||||
for vs in [self.critic.output_vars, self.target_critic.output_vars]:
|
|
||||||
assert len(vs) == 2
|
|
||||||
M, b = vs
|
|
||||||
assert 'kernel' in M.name
|
|
||||||
assert 'bias' in b.name
|
|
||||||
assert M.get_shape()[-1] == 1
|
|
||||||
assert b.get_shape()[-1] == 1
|
|
||||||
self.renormalize_Q_outputs_op += [M.assign(M * self.old_std / new_std)]
|
|
||||||
self.renormalize_Q_outputs_op += [b.assign((b * self.old_std + self.old_mean - new_mean) / new_std)]
|
|
||||||
|
|
||||||
def setup_stats(self):
|
|
||||||
ops = []
|
|
||||||
names = []
|
|
||||||
|
|
||||||
if self.normalize_returns:
|
|
||||||
ops += [self.ret_rms.mean, self.ret_rms.std]
|
|
||||||
names += ['ret_rms_mean', 'ret_rms_std']
|
|
||||||
|
|
||||||
if self.normalize_observations:
|
|
||||||
ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)]
|
|
||||||
names += ['obs_rms_mean', 'obs_rms_std']
|
|
||||||
|
|
||||||
ops += [tf.reduce_mean(self.critic_tf)]
|
|
||||||
names += ['reference_Q_mean']
|
|
||||||
ops += [reduce_std(self.critic_tf)]
|
|
||||||
names += ['reference_Q_std']
|
|
||||||
|
|
||||||
ops += [tf.reduce_mean(self.critic_with_actor_tf)]
|
|
||||||
names += ['reference_actor_Q_mean']
|
|
||||||
ops += [reduce_std(self.critic_with_actor_tf)]
|
|
||||||
names += ['reference_actor_Q_std']
|
|
||||||
|
|
||||||
ops += [tf.reduce_mean(self.actor_tf)]
|
|
||||||
names += ['reference_action_mean']
|
|
||||||
ops += [reduce_std(self.actor_tf)]
|
|
||||||
names += ['reference_action_std']
|
|
||||||
|
|
||||||
if self.param_noise:
|
|
||||||
ops += [tf.reduce_mean(self.perturbed_actor_tf)]
|
|
||||||
names += ['reference_perturbed_action_mean']
|
|
||||||
ops += [reduce_std(self.perturbed_actor_tf)]
|
|
||||||
names += ['reference_perturbed_action_std']
|
|
||||||
|
|
||||||
self.stats_ops = ops
|
|
||||||
self.stats_names = names
|
|
||||||
|
|
||||||
def pi(self, obs, apply_noise=True, compute_Q=True):
|
|
||||||
if self.param_noise is not None and apply_noise:
|
|
||||||
actor_tf = self.perturbed_actor_tf
|
|
||||||
else:
|
|
||||||
actor_tf = self.actor_tf
|
|
||||||
feed_dict = {self.obs0: [obs]}
|
|
||||||
if compute_Q:
|
|
||||||
action, q = self.sess.run([actor_tf, self.critic_with_actor_tf], feed_dict=feed_dict)
|
|
||||||
else:
|
|
||||||
action = self.sess.run(actor_tf, feed_dict=feed_dict)
|
|
||||||
q = None
|
|
||||||
action = action.flatten()
|
|
||||||
if self.action_noise is not None and apply_noise:
|
|
||||||
noise = self.action_noise()
|
|
||||||
assert noise.shape == action.shape
|
|
||||||
action += noise
|
|
||||||
action = np.clip(action, self.action_range[0], self.action_range[1])
|
|
||||||
return action, q
|
|
||||||
|
|
||||||
def store_transition(self, obs0, action, reward, obs1, terminal1):
|
|
||||||
reward *= self.reward_scale
|
|
||||||
self.memory.append(obs0, action, reward, obs1, terminal1)
|
|
||||||
if self.normalize_observations:
|
|
||||||
self.obs_rms.update(np.array([obs0]))
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
# Get a batch.
|
|
||||||
batch = self.memory.sample(batch_size=self.batch_size)
|
|
||||||
|
|
||||||
if self.normalize_returns and self.enable_popart:
|
|
||||||
old_mean, old_std, target_Q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_Q], feed_dict={
|
|
||||||
self.obs1: batch['obs1'],
|
|
||||||
self.rewards: batch['rewards'],
|
|
||||||
self.terminals1: batch['terminals1'].astype('float32'),
|
|
||||||
})
|
|
||||||
self.ret_rms.update(target_Q.flatten())
|
|
||||||
self.sess.run(self.renormalize_Q_outputs_op, feed_dict={
|
|
||||||
self.old_std : np.array([old_std]),
|
|
||||||
self.old_mean : np.array([old_mean]),
|
|
||||||
})
|
|
||||||
|
|
||||||
# Run sanity check. Disabled by default since it slows down things considerably.
|
|
||||||
# print('running sanity check')
|
|
||||||
# target_Q_new, new_mean, new_std = self.sess.run([self.target_Q, self.ret_rms.mean, self.ret_rms.std], feed_dict={
|
|
||||||
# self.obs1: batch['obs1'],
|
|
||||||
# self.rewards: batch['rewards'],
|
|
||||||
# self.terminals1: batch['terminals1'].astype('float32'),
|
|
||||||
# })
|
|
||||||
# print(target_Q_new, target_Q, new_mean, new_std)
|
|
||||||
# assert (np.abs(target_Q - target_Q_new) < 1e-3).all()
|
|
||||||
else:
|
|
||||||
target_Q = self.sess.run(self.target_Q, feed_dict={
|
|
||||||
self.obs1: batch['obs1'],
|
|
||||||
self.rewards: batch['rewards'],
|
|
||||||
self.terminals1: batch['terminals1'].astype('float32'),
|
|
||||||
})
|
|
||||||
|
|
||||||
# Get all gradients and perform a synced update.
|
|
||||||
ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
|
|
||||||
actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, feed_dict={
|
|
||||||
self.obs0: batch['obs0'],
|
|
||||||
self.actions: batch['actions'],
|
|
||||||
self.critic_target: target_Q,
|
|
||||||
})
|
|
||||||
self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
|
|
||||||
self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)
|
|
||||||
|
|
||||||
return critic_loss, actor_loss
|
|
||||||
|
|
||||||
def initialize(self, sess):
|
|
||||||
self.sess = sess
|
|
||||||
self.sess.run(tf.global_variables_initializer())
|
|
||||||
self.actor_optimizer.sync()
|
|
||||||
self.critic_optimizer.sync()
|
|
||||||
self.sess.run(self.target_init_updates)
|
|
||||||
|
|
||||||
def update_target_net(self):
|
|
||||||
self.sess.run(self.target_soft_updates)
|
|
||||||
|
|
||||||
def get_stats(self):
|
|
||||||
if self.stats_sample is None:
|
|
||||||
# Get a sample and keep that fixed for all further computations.
|
|
||||||
# This allows us to estimate the change in value for the same set of inputs.
|
|
||||||
self.stats_sample = self.memory.sample(batch_size=self.batch_size)
|
|
||||||
values = self.sess.run(self.stats_ops, feed_dict={
|
|
||||||
self.obs0: self.stats_sample['obs0'],
|
|
||||||
self.actions: self.stats_sample['actions'],
|
|
||||||
})
|
|
||||||
|
|
||||||
names = self.stats_names[:]
|
|
||||||
assert len(names) == len(values)
|
|
||||||
stats = dict(zip(names, values))
|
|
||||||
|
|
||||||
if self.param_noise is not None:
|
|
||||||
stats = {**stats, **self.param_noise.get_stats()}
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
def adapt_param_noise(self):
|
|
||||||
if self.param_noise is None:
|
|
||||||
return 0.
|
|
||||||
|
|
||||||
# Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
|
|
||||||
batch = self.memory.sample(batch_size=self.batch_size)
|
|
||||||
self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
|
|
||||||
self.param_noise_stddev: self.param_noise.current_stddev,
|
|
||||||
})
|
|
||||||
distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
|
|
||||||
self.obs0: batch['obs0'],
|
|
||||||
self.param_noise_stddev: self.param_noise.current_stddev,
|
|
||||||
})
|
|
||||||
|
|
||||||
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
|
|
||||||
self.param_noise.adapt(mean_distance)
|
|
||||||
return mean_distance
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
# Reset internal state after an episode is complete.
|
|
||||||
if self.action_noise is not None:
|
|
||||||
self.action_noise.reset()
|
|
||||||
if self.param_noise is not None:
|
|
||||||
self.sess.run(self.perturb_policy_ops, feed_dict={
|
|
||||||
self.param_noise_stddev: self.param_noise.current_stddev,
|
|
||||||
})
|
|
||||||
|
baselines/ddpg/ddpg_learner.py (new file, 380 lines)
@@ -0,0 +1,380 @@
|
|||||||
|
from copy import copy
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.contrib as tc
|
||||||
|
|
||||||
|
from baselines import logger
|
||||||
|
from baselines.common.mpi_adam import MpiAdam
|
||||||
|
import baselines.common.tf_util as U
|
||||||
|
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
||||||
|
from mpi4py import MPI
|
||||||
|
|
||||||
|
def normalize(x, stats):
|
||||||
|
if stats is None:
|
||||||
|
return x
|
||||||
|
return (x - stats.mean) / stats.std
|
||||||
|
|
||||||
|
|
||||||
|
def denormalize(x, stats):
|
||||||
|
if stats is None:
|
||||||
|
return x
|
||||||
|
return x * stats.std + stats.mean
|
||||||
|
|
||||||
|
def reduce_std(x, axis=None, keepdims=False):
|
||||||
|
return tf.sqrt(reduce_var(x, axis=axis, keepdims=keepdims))
|
||||||
|
|
||||||
|
def reduce_var(x, axis=None, keepdims=False):
|
||||||
|
m = tf.reduce_mean(x, axis=axis, keepdims=True)
|
||||||
|
devs_squared = tf.square(x - m)
|
||||||
|
return tf.reduce_mean(devs_squared, axis=axis, keepdims=keepdims)
|
||||||
|
|
||||||
|
def get_target_updates(vars, target_vars, tau):
|
||||||
|
logger.info('setting up target updates ...')
|
||||||
|
soft_updates = []
|
||||||
|
init_updates = []
|
||||||
|
assert len(vars) == len(target_vars)
|
||||||
|
for var, target_var in zip(vars, target_vars):
|
||||||
|
logger.info(' {} <- {}'.format(target_var.name, var.name))
|
||||||
|
init_updates.append(tf.assign(target_var, var))
|
||||||
|
soft_updates.append(tf.assign(target_var, (1. - tau) * target_var + tau * var))
|
||||||
|
assert len(init_updates) == len(vars)
|
||||||
|
assert len(soft_updates) == len(vars)
|
||||||
|
return tf.group(*init_updates), tf.group(*soft_updates)
|
||||||
|
|
||||||
|
|
||||||
|
def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev):
|
||||||
|
assert len(actor.vars) == len(perturbed_actor.vars)
|
||||||
|
assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)
|
||||||
|
|
||||||
|
updates = []
|
||||||
|
for var, perturbed_var in zip(actor.vars, perturbed_actor.vars):
|
||||||
|
if var in actor.perturbable_vars:
|
||||||
|
logger.info(' {} <- {} + noise'.format(perturbed_var.name, var.name))
|
||||||
|
updates.append(tf.assign(perturbed_var, var + tf.random_normal(tf.shape(var), mean=0., stddev=param_noise_stddev)))
|
||||||
|
else:
|
||||||
|
logger.info(' {} <- {}'.format(perturbed_var.name, var.name))
|
||||||
|
updates.append(tf.assign(perturbed_var, var))
|
||||||
|
assert len(updates) == len(actor.vars)
|
||||||
|
return tf.group(*updates)
|
||||||
|
|
||||||
|
|
||||||
|
class DDPG(object):
|
||||||
|
def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None,
|
||||||
|
gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
|
||||||
|
batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf),
|
||||||
|
adaptive_param_noise=True, adaptive_param_noise_policy_threshold=.1,
|
||||||
|
critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
|
||||||
|
# Inputs.
|
||||||
|
self.obs0 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs0')
|
||||||
|
self.obs1 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs1')
|
||||||
|
self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1')
|
||||||
|
self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
|
||||||
|
self.actions = tf.placeholder(tf.float32, shape=(None,) + action_shape, name='actions')
|
||||||
|
self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target')
|
||||||
|
self.param_noise_stddev = tf.placeholder(tf.float32, shape=(), name='param_noise_stddev')
|
||||||
|
|
||||||
|
# Parameters.
|
||||||
|
self.gamma = gamma
|
||||||
|
self.tau = tau
|
||||||
|
self.memory = memory
|
||||||
|
self.normalize_observations = normalize_observations
|
||||||
|
self.normalize_returns = normalize_returns
|
||||||
|
self.action_noise = action_noise
|
||||||
|
self.param_noise = param_noise
|
||||||
|
self.action_range = action_range
|
||||||
|
self.return_range = return_range
|
||||||
|
self.observation_range = observation_range
|
||||||
|
self.critic = critic
|
||||||
|
self.actor = actor
|
||||||
|
self.actor_lr = actor_lr
|
||||||
|
self.critic_lr = critic_lr
|
||||||
|
self.clip_norm = clip_norm
|
||||||
|
self.enable_popart = enable_popart
|
||||||
|
self.reward_scale = reward_scale
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.stats_sample = None
|
||||||
|
self.critic_l2_reg = critic_l2_reg
|
||||||
|
|
||||||
|
# Observation normalization.
|
||||||
|
if self.normalize_observations:
|
||||||
|
with tf.variable_scope('obs_rms'):
|
||||||
|
self.obs_rms = RunningMeanStd(shape=observation_shape)
|
||||||
|
else:
|
||||||
|
self.obs_rms = None
|
||||||
|
normalized_obs0 = tf.clip_by_value(normalize(self.obs0, self.obs_rms),
|
||||||
|
self.observation_range[0], self.observation_range[1])
|
||||||
|
normalized_obs1 = tf.clip_by_value(normalize(self.obs1, self.obs_rms),
|
||||||
|
self.observation_range[0], self.observation_range[1])
|
||||||
|
|
||||||
|
# Return normalization.
|
||||||
|
if self.normalize_returns:
|
||||||
|
with tf.variable_scope('ret_rms'):
|
||||||
|
self.ret_rms = RunningMeanStd()
|
||||||
|
else:
|
||||||
|
self.ret_rms = None
|
||||||
|
|
||||||
|
# Create target networks.
|
||||||
|
target_actor = copy(actor)
|
||||||
|
target_actor.name = 'target_actor'
|
||||||
|
self.target_actor = target_actor
|
||||||
|
target_critic = copy(critic)
|
||||||
|
target_critic.name = 'target_critic'
|
||||||
|
self.target_critic = target_critic
|
||||||
|
|
||||||
|
# Create networks and core TF parts that are shared across setup parts.
|
||||||
|
self.actor_tf = actor(normalized_obs0)
|
||||||
|
self.normalized_critic_tf = critic(normalized_obs0, self.actions)
|
||||||
|
self.critic_tf = denormalize(tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
|
||||||
|
self.normalized_critic_with_actor_tf = critic(normalized_obs0, self.actor_tf, reuse=True)
|
||||||
|
self.critic_with_actor_tf = denormalize(tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
|
||||||
|
Q_obs1 = denormalize(target_critic(normalized_obs1, target_actor(normalized_obs1)), self.ret_rms)
|
||||||
|
self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1
|
||||||
|
|
||||||
|
# Set up parts.
|
||||||
|
if self.param_noise is not None:
|
||||||
|
self.setup_param_noise(normalized_obs0)
|
||||||
|
self.setup_actor_optimizer()
|
||||||
|
self.setup_critic_optimizer()
|
||||||
|
if self.normalize_returns and self.enable_popart:
|
||||||
|
self.setup_popart()
|
||||||
|
self.setup_stats()
|
||||||
|
self.setup_target_network_updates()
|
||||||
|
|
||||||
|
self.initial_state = None # recurrent architectures not supported yet
|
||||||
|
|
||||||
|
def setup_target_network_updates(self):
|
||||||
|
actor_init_updates, actor_soft_updates = get_target_updates(self.actor.vars, self.target_actor.vars, self.tau)
|
||||||
|
critic_init_updates, critic_soft_updates = get_target_updates(self.critic.vars, self.target_critic.vars, self.tau)
|
||||||
|
self.target_init_updates = [actor_init_updates, critic_init_updates]
|
||||||
|
self.target_soft_updates = [actor_soft_updates, critic_soft_updates]
|
||||||
|
|
||||||
|
def setup_param_noise(self, normalized_obs0):
|
||||||
|
assert self.param_noise is not None
|
||||||
|
|
||||||
|
# Configure perturbed actor.
|
||||||
|
param_noise_actor = copy(self.actor)
|
||||||
|
param_noise_actor.name = 'param_noise_actor'
|
||||||
|
self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
|
||||||
|
logger.info('setting up param noise')
|
||||||
|
self.perturb_policy_ops = get_perturbed_actor_updates(self.actor, param_noise_actor, self.param_noise_stddev)
|
||||||
|
|
||||||
|
# Configure separate copy for stddev adoption.
|
||||||
|
adaptive_param_noise_actor = copy(self.actor)
|
||||||
|
adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
|
||||||
|
adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
|
||||||
|
self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
|
||||||
|
self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))
|
||||||
|
|
||||||
|
def setup_actor_optimizer(self):
|
||||||
|
logger.info('setting up actor optimizer')
|
||||||
|
self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)
|
||||||
|
actor_shapes = [var.get_shape().as_list() for var in self.actor.trainable_vars]
|
||||||
|
actor_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
|
||||||
|
logger.info(' actor shapes: {}'.format(actor_shapes))
|
||||||
|
logger.info(' actor params: {}'.format(actor_nb_params))
|
||||||
|
self.actor_grads = U.flatgrad(self.actor_loss, self.actor.trainable_vars, clip_norm=self.clip_norm)
|
||||||
|
self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
|
||||||
|
beta1=0.9, beta2=0.999, epsilon=1e-08)
|
||||||
|
|
||||||
|
def setup_critic_optimizer(self):
|
||||||
|
logger.info('setting up critic optimizer')
|
||||||
|
normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
|
||||||
|
self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
|
||||||
|
if self.critic_l2_reg > 0.:
|
||||||
|
critic_reg_vars = [var for var in self.critic.trainable_vars if 'kernel' in var.name and 'output' not in var.name]
|
||||||
|
for var in critic_reg_vars:
|
||||||
|
logger.info(' regularizing: {}'.format(var.name))
|
||||||
|
logger.info(' applying l2 regularization with {}'.format(self.critic_l2_reg))
|
||||||
|
critic_reg = tc.layers.apply_regularization(
|
||||||
|
tc.layers.l2_regularizer(self.critic_l2_reg),
|
||||||
|
weights_list=critic_reg_vars
|
||||||
|
)
|
||||||
|
self.critic_loss += critic_reg
|
||||||
|
critic_shapes = [var.get_shape().as_list() for var in self.critic.trainable_vars]
|
||||||
|
critic_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
|
||||||
|
logger.info(' critic shapes: {}'.format(critic_shapes))
|
||||||
|
logger.info(' critic params: {}'.format(critic_nb_params))
|
||||||
|
self.critic_grads = U.flatgrad(self.critic_loss, self.critic.trainable_vars, clip_norm=self.clip_norm)
|
||||||
|
self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars,
|
||||||
|
beta1=0.9, beta2=0.999, epsilon=1e-08)
|
||||||
|
|
||||||
|
def setup_popart(self):
|
||||||
|
# See https://arxiv.org/pdf/1602.07714.pdf for details.
|
||||||
|
self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
|
||||||
|
new_std = self.ret_rms.std
|
||||||
|
self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
|
||||||
|
new_mean = self.ret_rms.mean
|
||||||
|
|
||||||
|
self.renormalize_Q_outputs_op = []
|
||||||
|
for vs in [self.critic.output_vars, self.target_critic.output_vars]:
|
||||||
|
assert len(vs) == 2
|
||||||
|
M, b = vs
|
||||||
|
assert 'kernel' in M.name
|
||||||
|
assert 'bias' in b.name
|
||||||
|
assert M.get_shape()[-1] == 1
|
||||||
|
assert b.get_shape()[-1] == 1
|
||||||
|
self.renormalize_Q_outputs_op += [M.assign(M * self.old_std / new_std)]
|
||||||
|
self.renormalize_Q_outputs_op += [b.assign((b * self.old_std + self.old_mean - new_mean) / new_std)]
|
||||||
|
|
||||||
|
def setup_stats(self):
|
||||||
|
ops = []
|
||||||
|
names = []
|
||||||
|
|
||||||
|
if self.normalize_returns:
|
||||||
|
ops += [self.ret_rms.mean, self.ret_rms.std]
|
||||||
|
names += ['ret_rms_mean', 'ret_rms_std']
|
||||||
|
|
||||||
|
if self.normalize_observations:
|
||||||
|
ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)]
|
||||||
|
names += ['obs_rms_mean', 'obs_rms_std']
|
||||||
|
|
||||||
|
ops += [tf.reduce_mean(self.critic_tf)]
|
||||||
|
names += ['reference_Q_mean']
|
||||||
|
ops += [reduce_std(self.critic_tf)]
|
||||||
|
names += ['reference_Q_std']
|
||||||
|
|
||||||
|
ops += [tf.reduce_mean(self.critic_with_actor_tf)]
|
||||||
|
names += ['reference_actor_Q_mean']
|
||||||
|
ops += [reduce_std(self.critic_with_actor_tf)]
|
||||||
|
names += ['reference_actor_Q_std']
|
||||||
|
|
||||||
|
ops += [tf.reduce_mean(self.actor_tf)]
|
||||||
|
names += ['reference_action_mean']
|
||||||
|
ops += [reduce_std(self.actor_tf)]
|
||||||
|
names += ['reference_action_std']
|
||||||
|
|
||||||
|
if self.param_noise:
|
||||||
|
ops += [tf.reduce_mean(self.perturbed_actor_tf)]
|
||||||
|
names += ['reference_perturbed_action_mean']
|
||||||
|
ops += [reduce_std(self.perturbed_actor_tf)]
|
||||||
|
names += ['reference_perturbed_action_std']
|
||||||
|
|
||||||
|
self.stats_ops = ops
|
||||||
|
self.stats_names = names
|
||||||
|
|
||||||
|
def step(self, obs, apply_noise=True, compute_Q=True):
|
||||||
|
if self.param_noise is not None and apply_noise:
|
||||||
|
actor_tf = self.perturbed_actor_tf
|
||||||
|
else:
|
||||||
|
actor_tf = self.actor_tf
|
||||||
|
feed_dict = {self.obs0: U.adjust_shape(self.obs0, [obs])}
|
||||||
|
if compute_Q:
|
||||||
|
action, q = self.sess.run([actor_tf, self.critic_with_actor_tf], feed_dict=feed_dict)
|
||||||
|
else:
|
||||||
|
action = self.sess.run(actor_tf, feed_dict=feed_dict)
|
||||||
|
q = None
|
||||||
|
action = action.flatten()
|
||||||
|
if self.action_noise is not None and apply_noise:
|
||||||
|
noise = self.action_noise()
|
||||||
|
assert noise.shape == action.shape
|
||||||
|
action += noise
|
||||||
|
action = np.clip(action, self.action_range[0], self.action_range[1])
|
||||||
|
return action, q, None, None
|
||||||
|
|
||||||
|
def store_transition(self, obs0, action, reward, obs1, terminal1):
|
||||||
|
reward *= self.reward_scale
|
||||||
|
self.memory.append(obs0, action, reward, obs1, terminal1)
|
||||||
|
if self.normalize_observations:
|
||||||
|
self.obs_rms.update(np.array([obs0]))
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
# Get a batch.
|
||||||
|
batch = self.memory.sample(batch_size=self.batch_size)
|
||||||
|
|
||||||
|
if self.normalize_returns and self.enable_popart:
|
||||||
|
old_mean, old_std, target_Q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_Q], feed_dict={
|
||||||
|
self.obs1: batch['obs1'],
|
||||||
|
self.rewards: batch['rewards'],
|
||||||
|
self.terminals1: batch['terminals1'].astype('float32'),
|
||||||
|
})
|
||||||
|
self.ret_rms.update(target_Q.flatten())
|
||||||
|
self.sess.run(self.renormalize_Q_outputs_op, feed_dict={
|
||||||
|
self.old_std : np.array([old_std]),
|
||||||
|
self.old_mean : np.array([old_mean]),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Run sanity check. Disabled by default since it slows down things considerably.
|
||||||
|
# print('running sanity check')
|
||||||
|
# target_Q_new, new_mean, new_std = self.sess.run([self.target_Q, self.ret_rms.mean, self.ret_rms.std], feed_dict={
|
||||||
|
# self.obs1: batch['obs1'],
|
||||||
|
# self.rewards: batch['rewards'],
|
||||||
|
# self.terminals1: batch['terminals1'].astype('float32'),
|
||||||
|
# })
|
||||||
|
# print(target_Q_new, target_Q, new_mean, new_std)
|
||||||
|
# assert (np.abs(target_Q - target_Q_new) < 1e-3).all()
|
||||||
|
else:
|
||||||
|
target_Q = self.sess.run(self.target_Q, feed_dict={
|
||||||
|
self.obs1: batch['obs1'],
|
||||||
|
self.rewards: batch['rewards'],
|
||||||
|
self.terminals1: batch['terminals1'].astype('float32'),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Get all gradients and perform a synced update.
|
||||||
|
ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
|
||||||
|
actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, feed_dict={
|
||||||
|
self.obs0: batch['obs0'],
|
||||||
|
self.actions: batch['actions'],
|
||||||
|
self.critic_target: target_Q,
|
||||||
|
})
|
||||||
|
self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
|
||||||
|
self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)
|
||||||
|
|
||||||
|
return critic_loss, actor_loss
|
||||||
|
|
||||||
|
def initialize(self, sess):
|
||||||
|
self.sess = sess
|
||||||
|
self.sess.run(tf.global_variables_initializer())
|
||||||
|
self.actor_optimizer.sync()
|
||||||
|
self.critic_optimizer.sync()
|
||||||
|
self.sess.run(self.target_init_updates)
|
||||||
|
|
||||||
|
def update_target_net(self):
|
||||||
|
self.sess.run(self.target_soft_updates)
|
||||||
|
|
||||||
|
def get_stats(self):
|
||||||
|
if self.stats_sample is None:
|
||||||
|
# Get a sample and keep that fixed for all further computations.
|
||||||
|
# This allows us to estimate the change in value for the same set of inputs.
|
||||||
|
self.stats_sample = self.memory.sample(batch_size=self.batch_size)
|
||||||
|
values = self.sess.run(self.stats_ops, feed_dict={
|
||||||
|
self.obs0: self.stats_sample['obs0'],
|
||||||
|
self.actions: self.stats_sample['actions'],
|
||||||
|
})
|
||||||
|
|
||||||
|
names = self.stats_names[:]
|
||||||
|
assert len(names) == len(values)
|
||||||
|
stats = dict(zip(names, values))
|
||||||
|
|
||||||
|
if self.param_noise is not None:
|
||||||
|
stats = {**stats, **self.param_noise.get_stats()}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def adapt_param_noise(self):
|
||||||
|
if self.param_noise is None:
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
# Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
|
||||||
|
batch = self.memory.sample(batch_size=self.batch_size)
|
||||||
|
self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
|
||||||
|
self.param_noise_stddev: self.param_noise.current_stddev,
|
||||||
|
})
|
||||||
|
distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
|
||||||
|
self.obs0: batch['obs0'],
|
||||||
|
self.param_noise_stddev: self.param_noise.current_stddev,
|
||||||
|
})
|
||||||
|
|
||||||
|
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
|
||||||
|
self.param_noise.adapt(mean_distance)
|
||||||
|
return mean_distance
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
# Reset internal state after an episode is complete.
|
||||||
|
if self.action_noise is not None:
|
||||||
|
self.action_noise.reset()
|
||||||
|
if self.param_noise is not None:
|
||||||
|
self.sess.run(self.perturb_policy_ops, feed_dict={
|
||||||
|
self.param_noise_stddev: self.param_noise.current_stddev,
|
||||||
|
})
|
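The new ddpg_learner.py keeps the DDPG agent itself, while the rollout and evaluation loop lives in learn(); the glue is that DDPG.step() now returns a 4-tuple like the other baselines policies, with the recurrent-state and extra slots left as None. A condensed sketch of one rollout step as learn() performs it (agent and env are the objects constructed there):

def rollout_step(agent, env, obs):
    # step() returns (action, q, state, extra); the last two are always None for DDPG
    action, q, _, _ = agent.step(obs, apply_noise=True, compute_Q=True)
    # actions come out in [-1, 1]; learn() rescales by the action-space bound
    new_obs, reward, done, info = env.step(env.action_space.high * action)
    agent.store_transition(obs, action, reward, new_obs, done)
    if done:
        agent.reset()            # resets action noise / re-perturbs the actor
        new_obs = env.reset()
    return new_obs, reward, done, q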
@@ -1,123 +0,0 @@
import argparse
import time
import os
import logging
from baselines import logger, bench
from baselines.common.misc_util import (
    set_global_seeds,
    boolean_flag,
)
import baselines.ddpg.training as training
from baselines.ddpg.models import Actor, Critic
from baselines.ddpg.memory import Memory
from baselines.ddpg.noise import *

import gym
import tensorflow as tf
from mpi4py import MPI


def run(env_id, seed, noise_type, layer_norm, evaluation, **kwargs):
    # Configure things.
    rank = MPI.COMM_WORLD.Get_rank()
    if rank != 0:
        logger.set_level(logger.DISABLED)

    # Create envs.
    env = gym.make(env_id)
    env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))

    if evaluation and rank==0:
        eval_env = gym.make(env_id)
        eval_env = bench.Monitor(eval_env, os.path.join(logger.get_dir(), 'gym_eval'))
        env = bench.Monitor(env, None)
    else:
        eval_env = None

    # Parse noise_type
    action_noise = None
    param_noise = None
    nb_actions = env.action_space.shape[-1]
    for current_noise_type in noise_type.split(','):
        current_noise_type = current_noise_type.strip()
        if current_noise_type == 'none':
            pass
        elif 'adaptive-param' in current_noise_type:
            _, stddev = current_noise_type.split('_')
            param_noise = AdaptiveParamNoiseSpec(initial_stddev=float(stddev), desired_action_stddev=float(stddev))
        elif 'normal' in current_noise_type:
            _, stddev = current_noise_type.split('_')
            action_noise = NormalActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions))
        elif 'ou' in current_noise_type:
            _, stddev = current_noise_type.split('_')
            action_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions))
        else:
            raise RuntimeError('unknown noise type "{}"'.format(current_noise_type))

    # Configure components.
    memory = Memory(limit=int(1e6), action_shape=env.action_space.shape, observation_shape=env.observation_space.shape)
    critic = Critic(layer_norm=layer_norm)
    actor = Actor(nb_actions, layer_norm=layer_norm)

    # Seed everything to make things reproducible.
    seed = seed + 1000000 * rank
    logger.info('rank {}: seed={}, logdir={}'.format(rank, seed, logger.get_dir()))
    tf.reset_default_graph()
    set_global_seeds(seed)
    env.seed(seed)
    if eval_env is not None:
        eval_env.seed(seed)

    # Disable logging for rank != 0 to avoid noise.
    if rank == 0:
        start_time = time.time()
    training.train(env=env, eval_env=eval_env, param_noise=param_noise,
        action_noise=action_noise, actor=actor, critic=critic, memory=memory, **kwargs)
    env.close()
    if eval_env is not None:
        eval_env.close()
    if rank == 0:
        logger.info('total runtime: {}s'.format(time.time() - start_time))


def parse_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--env-id', type=str, default='HalfCheetah-v1')
    boolean_flag(parser, 'render-eval', default=False)
    boolean_flag(parser, 'layer-norm', default=True)
    boolean_flag(parser, 'render', default=False)
    boolean_flag(parser, 'normalize-returns', default=False)
    boolean_flag(parser, 'normalize-observations', default=True)
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--critic-l2-reg', type=float, default=1e-2)
    parser.add_argument('--batch-size', type=int, default=64)  # per MPI worker
    parser.add_argument('--actor-lr', type=float, default=1e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    boolean_flag(parser, 'popart', default=False)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--reward-scale', type=float, default=1.)
    parser.add_argument('--clip-norm', type=float, default=None)
    parser.add_argument('--nb-epochs', type=int, default=500)  # with default settings, perform 1M steps total
    parser.add_argument('--nb-epoch-cycles', type=int, default=20)
    parser.add_argument('--nb-train-steps', type=int, default=50)  # per epoch cycle and MPI worker
    parser.add_argument('--nb-eval-steps', type=int, default=100)  # per epoch cycle and MPI worker
    parser.add_argument('--nb-rollout-steps', type=int, default=100)  # per epoch cycle and MPI worker
    parser.add_argument('--noise-type', type=str, default='adaptive-param_0.2')  # choices are adaptive-param_xx, ou_xx, normal_xx, none
    parser.add_argument('--num-timesteps', type=int, default=None)
    boolean_flag(parser, 'evaluation', default=False)
    args = parser.parse_args()
    # we don't directly specify timesteps for this script, so make sure that if we do specify them
    # they agree with the other parameters
    if args.num_timesteps is not None:
        assert(args.num_timesteps == args.nb_epochs * args.nb_epoch_cycles * args.nb_rollout_steps)
    dict_args = vars(args)
    del dict_args['num_timesteps']
    return dict_args


if __name__ == '__main__':
    args = parse_args()
    if MPI.COMM_WORLD.Get_rank() == 0:
        logger.configure()
    # Run actual script.
    run(**args)
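Two notes on the launcher deleted above. The num-timesteps consistency check follows from the defaults: 500 epochs * 20 cycles per epoch * 100 rollout steps per cycle = 1,000,000 environment steps per MPI worker, which is what the comment next to --nb-epochs refers to. With this script gone, DDPG is presumably launched through the shared baselines runner instead, e.g. something like python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6 (the exact flags are an assumption based on the common run interface, not something this diff shows).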
@@ -1,10 +1,11 @@
import tensorflow as tf
-import tensorflow.contrib as tc
+from baselines.common.models import get_network_builder


class Model(object):
-    def __init__(self, name):
+    def __init__(self, name, network='mlp', **network_kwargs):
        self.name = name
+        self.network_builder = get_network_builder(network)(**network_kwargs)

    @property
    def vars(self):
@@ -20,54 +21,27 @@ class Model(object):


class Actor(Model):
-    def __init__(self, nb_actions, name='actor', layer_norm=True):
-        super(Actor, self).__init__(name=name)
+    def __init__(self, nb_actions, name='actor', network='mlp', **network_kwargs):
+        super().__init__(name=name, network=network, **network_kwargs)
        self.nb_actions = nb_actions
-        self.layer_norm = layer_norm

    def __call__(self, obs, reuse=False):
-        with tf.variable_scope(self.name) as scope:
-            if reuse:
-                scope.reuse_variables()
-
-            x = obs
-            x = tf.layers.dense(x, 64)
-            if self.layer_norm:
-                x = tc.layers.layer_norm(x, center=True, scale=True)
-            x = tf.nn.relu(x)
-
-            x = tf.layers.dense(x, 64)
-            if self.layer_norm:
-                x = tc.layers.layer_norm(x, center=True, scale=True)
-            x = tf.nn.relu(x)
-
+        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
+            x = self.network_builder(obs)
            x = tf.layers.dense(x, self.nb_actions, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
            x = tf.nn.tanh(x)
        return x


class Critic(Model):
-    def __init__(self, name='critic', layer_norm=True):
-        super(Critic, self).__init__(name=name)
-        self.layer_norm = layer_norm
+    def __init__(self, name='critic', network='mlp', **network_kwargs):
+        super().__init__(name=name, network=network, **network_kwargs)
+        self.layer_norm = True

    def __call__(self, obs, action, reuse=False):
-        with tf.variable_scope(self.name) as scope:
-            if reuse:
-                scope.reuse_variables()
-
-            x = obs
-            x = tf.layers.dense(x, 64)
-            if self.layer_norm:
-                x = tc.layers.layer_norm(x, center=True, scale=True)
-            x = tf.nn.relu(x)
-
-            x = tf.concat([x, action], axis=-1)
-            x = tf.layers.dense(x, 64)
-            if self.layer_norm:
-                x = tc.layers.layer_norm(x, center=True, scale=True)
-            x = tf.nn.relu(x)
-
+        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
+            x = tf.concat([obs, action], axis=-1)  # this assumes observation and action can be concatenated
+            x = self.network_builder(x)
            x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
        return x
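A hedged usage sketch of the refactored classes above: the network argument is resolved through get_network_builder, so what used to be the explicit layer_norm constructor flag now travels in network_kwargs. The placeholder shapes below (17-dimensional observations, 6 actions, roughly HalfCheetah-sized) and the mlp kwargs are illustrative assumptions, not taken from this diff.

import tensorflow as tf
from baselines.ddpg.models import Actor, Critic

nb_actions = 6
obs0 = tf.placeholder(tf.float32, shape=(None, 17), name='obs0')
actions = tf.placeholder(tf.float32, shape=(None, nb_actions), name='actions')

actor = Actor(nb_actions, network='mlp', num_layers=2, num_hidden=64, layer_norm=True)
critic = Critic(network='mlp', num_layers=2, num_hidden=64, layer_norm=True)

pi_tf = actor(obs0)           # tanh-squashed actions in [-1, 1]
q_tf = critic(obs0, actions)  # one Q-value per (observation, action) pair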
@@ -1,191 +0,0 @@
import os
import time
from collections import deque
import pickle

from baselines.ddpg.ddpg import DDPG
import baselines.common.tf_util as U

from baselines import logger
import numpy as np
import tensorflow as tf
from mpi4py import MPI


def train(env, nb_epochs, nb_epoch_cycles, render_eval, reward_scale, render, param_noise, actor, critic,
    normalize_returns, normalize_observations, critic_l2_reg, actor_lr, critic_lr, action_noise,
    popart, gamma, clip_norm, nb_train_steps, nb_rollout_steps, nb_eval_steps, batch_size, memory,
    tau=0.01, eval_env=None, param_noise_adaption_interval=50):
    rank = MPI.COMM_WORLD.Get_rank()

    assert (np.abs(env.action_space.low) == env.action_space.high).all()  # we assume symmetric actions.
    max_action = env.action_space.high
    logger.info('scaling actions by {} before executing in env'.format(max_action))
    agent = DDPG(actor, critic, memory, env.observation_space.shape, env.action_space.shape,
        gamma=gamma, tau=tau, normalize_returns=normalize_returns, normalize_observations=normalize_observations,
        batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
        actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm,
        reward_scale=reward_scale)
    logger.info('Using agent with the following configuration:')
    logger.info(str(agent.__dict__.items()))

    # Set up logging stuff only for a single worker.
    if rank == 0:
        saver = tf.train.Saver()
    else:
        saver = None

    step = 0
    episode = 0
    eval_episode_rewards_history = deque(maxlen=100)
    episode_rewards_history = deque(maxlen=100)
    with U.single_threaded_session() as sess:
        # Prepare everything.
        agent.initialize(sess)
        sess.graph.finalize()

        agent.reset()
        obs = env.reset()
        if eval_env is not None:
            eval_obs = eval_env.reset()
        done = False
        episode_reward = 0.
        episode_step = 0
        episodes = 0
        t = 0

        epoch = 0
        start_time = time.time()

        epoch_episode_rewards = []
        epoch_episode_steps = []
        epoch_episode_eval_rewards = []
        epoch_episode_eval_steps = []
        epoch_start_time = time.time()
        epoch_actions = []
        epoch_qs = []
        epoch_episodes = 0
        for epoch in range(nb_epochs):
            for cycle in range(nb_epoch_cycles):
                # Perform rollouts.
                for t_rollout in range(nb_rollout_steps):
                    # Predict next action.
                    action, q = agent.pi(obs, apply_noise=True, compute_Q=True)
                    assert action.shape == env.action_space.shape

                    # Execute next action.
                    if rank == 0 and render:
                        env.render()
                    assert max_action.shape == action.shape
                    new_obs, r, done, info = env.step(max_action * action)  # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
                    t += 1
                    if rank == 0 and render:
                        env.render()
                    episode_reward += r
                    episode_step += 1

                    # Book-keeping.
                    epoch_actions.append(action)
                    epoch_qs.append(q)
                    agent.store_transition(obs, action, r, new_obs, done)
                    obs = new_obs

                    if done:
                        # Episode done.
                        epoch_episode_rewards.append(episode_reward)
                        episode_rewards_history.append(episode_reward)
                        epoch_episode_steps.append(episode_step)
                        episode_reward = 0.
                        episode_step = 0
                        epoch_episodes += 1
                        episodes += 1

                        agent.reset()
                        obs = env.reset()

                # Train.
                epoch_actor_losses = []
                epoch_critic_losses = []
                epoch_adaptive_distances = []
                for t_train in range(nb_train_steps):
                    # Adapt param noise, if necessary.
                    if memory.nb_entries >= batch_size and t_train % param_noise_adaption_interval == 0:
                        distance = agent.adapt_param_noise()
                        epoch_adaptive_distances.append(distance)

                    cl, al = agent.train()
                    epoch_critic_losses.append(cl)
                    epoch_actor_losses.append(al)
                    agent.update_target_net()

                # Evaluate.
                eval_episode_rewards = []
                eval_qs = []
                if eval_env is not None:
                    eval_episode_reward = 0.
                    for t_rollout in range(nb_eval_steps):
                        eval_action, eval_q = agent.pi(eval_obs, apply_noise=False, compute_Q=True)
                        eval_obs, eval_r, eval_done, eval_info = eval_env.step(max_action * eval_action)  # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
                        if render_eval:
                            eval_env.render()
                        eval_episode_reward += eval_r

                        eval_qs.append(eval_q)
                        if eval_done:
                            eval_obs = eval_env.reset()
                            eval_episode_rewards.append(eval_episode_reward)
                            eval_episode_rewards_history.append(eval_episode_reward)
                            eval_episode_reward = 0.

            mpi_size = MPI.COMM_WORLD.Get_size()
            # Log stats.
            # XXX shouldn't call np.mean on variable length lists
            duration = time.time() - start_time
            stats = agent.get_stats()
            combined_stats = stats.copy()
            combined_stats['rollout/return'] = np.mean(epoch_episode_rewards)
            combined_stats['rollout/return_history'] = np.mean(episode_rewards_history)
            combined_stats['rollout/episode_steps'] = np.mean(epoch_episode_steps)
            combined_stats['rollout/actions_mean'] = np.mean(epoch_actions)
            combined_stats['rollout/Q_mean'] = np.mean(epoch_qs)
            combined_stats['train/loss_actor'] = np.mean(epoch_actor_losses)
            combined_stats['train/loss_critic'] = np.mean(epoch_critic_losses)
            combined_stats['train/param_noise_distance'] = np.mean(epoch_adaptive_distances)
            combined_stats['total/duration'] = duration
            combined_stats['total/steps_per_second'] = float(t) / float(duration)
            combined_stats['total/episodes'] = episodes
            combined_stats['rollout/episodes'] = epoch_episodes
            combined_stats['rollout/actions_std'] = np.std(epoch_actions)
            # Evaluation statistics.
            if eval_env is not None:
                combined_stats['eval/return'] = eval_episode_rewards
                combined_stats['eval/return_history'] = np.mean(eval_episode_rewards_history)
                combined_stats['eval/Q'] = eval_qs
                combined_stats['eval/episodes'] = len(eval_episode_rewards)

            def as_scalar(x):
                if isinstance(x, np.ndarray):
                    assert x.size == 1
                    return x[0]
                elif np.isscalar(x):
                    return x
                else:
                    raise ValueError('expected scalar, got %s'%x)
            combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([as_scalar(x) for x in combined_stats.values()]))
            combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}

            # Total statistics.
            combined_stats['total/epochs'] = epoch + 1
            combined_stats['total/steps'] = t

            for key in sorted(combined_stats.keys()):
                logger.record_tabular(key, combined_stats[key])
            logger.dump_tabular()
            logger.info('')
            logdir = logger.get_dir()
            if rank == 0 and logdir:
                if hasattr(env, 'get_state'):
                    with open(os.path.join(logdir, 'env_state.pkl'), 'wb') as f:
                        pickle.dump(env.get_state(), f)
                if eval_env and hasattr(eval_env, 'get_state'):
                    with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f:
                        pickle.dump(eval_env.get_state(), f)
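One detail of the deleted loop worth calling out: the per-worker statistics are combined with an MPI sum and then divided by the world size, so every logged value is a mean over workers. A standalone toy version of that pattern (the dictionary contents are illustrative):

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
local_stats = {'rollout/return': float(comm.Get_rank()), 'train/loss_actor': 0.5}
sums = comm.allreduce(np.array(list(local_stats.values())))  # lowercase allreduce defaults to MPI.SUM
averaged = {k: v / comm.Get_size() for k, v in zip(local_stats.keys(), sums)}
# With N ranks, averaged['rollout/return'] == (0 + 1 + ... + (N - 1)) / N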
@@ -98,7 +98,12 @@ def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **netwo
    def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
-            latent, _ = network(input_placeholder)
+            latent = network(input_placeholder)
+            if isinstance(latent, tuple):
+                if latent[1] is not None:
+                    raise NotImplementedError("DQN is not compatible with recurrent policies yet")
+                latent = latent[0]

            latent = layers.flatten(latent)

            with tf.variable_scope("action_value"):
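The isinstance guard added above exists because network builders used to return a (latent, recurrent_state) tuple and may now return a single tensor; q_func_builder accepts either shape. Two illustrative builders (the function names are made up for this sketch, and a non-None second element would still raise, since DQN has no recurrent support yet):

import tensorflow as tf

def my_flat_net(x):
    # New-style builder: return just the latent tensor.
    return tf.layers.dense(tf.layers.flatten(x), 64, activation=tf.nn.relu)

def my_tuple_net(x):
    # Old-style builder: (latent, recurrent_state) with a None state.
    return tf.layers.dense(tf.layers.flatten(x), 64, activation=tf.nn.relu), None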