Compare commits
3 Commits: master...param-nois

SHA1:
- 45a1297ec0
- 6964729ed0
- 13c4107264
@@ -15,3 +15,4 @@ pip install baselines
 - [DQN](baselines/deepq)
 - [PPO](baselines/pposgd)
 - [TRPO](baselines/trpo_mpi)
+- [DDPG](baselines/ddpg)
@@ -3,7 +3,10 @@ import tempfile
 import zipfile

 from azure.common import AzureMissingResourceHttpError
-from azure.storage.blob import BlobService
+try:
+    from azure.storage.blob import BlobService
+except ImportError:
+    from azure.storage.blob import BlockBlobService as BlobService
 from shutil import unpack_archive
 from threading import Event
@@ -114,18 +117,23 @@ class Container(object):
             arcpath = os.path.join(td, "archive.zip")
             for backup_blob_name in [blob_name, blob_name + '.backup']:
                 try:
-                    blob_size = self._service.get_blob_properties(
+                    properties = self._service.get_blob_properties(
                         blob_name=backup_blob_name,
                         container_name=self._container_name
-                    )['content-length']
+                    )
+                    if hasattr(properties, 'properties'):
+                        # Annoyingly, Azure has changed the API and this now returns a blob
+                        # instead of it's properties with up-to-date azure package.
+                        blob_size = properties.properties.content_length
+                    else:
+                        blob_size = properties['content-length']
                     if int(blob_size) > 0:
                         self._service.get_blob_to_path(
                             container_name=self._container_name,
                             blob_name=backup_blob_name,
                             file_path=arcpath,
                             max_connections=4,
-                            progress_callback=progress_callback,
-                            max_retries=10)
+                            progress_callback=progress_callback)
                         unpack_archive(arcpath, dest_path)
                         download_done.wait()
                         return True
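The hunk above handles the two return types of `get_blob_properties` across Azure SDK versions: the older `BlobService` returns a dict keyed by `'content-length'`, while the newer `BlockBlobService` returns a blob object whose size lives under `.properties.content_length`. A minimal standalone sketch of the same compatibility check (the `props` argument is a hypothetical stand-in for either return type):

```python
def extract_blob_size(props):
    # Sketch only: `props` may be a dict (old azure SDK) or a blob object (new SDK).
    if hasattr(props, 'properties'):
        return int(props.properties.content_length)
    return int(props['content-length'])
```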
@@ -237,8 +237,9 @@ def boolean_flag(parser, name, default=False, help=None):
     help: str
         help string for the flag
     """
-    parser.add_argument("--" + name, action="store_true", default=default, help=help)
-    parser.add_argument("--no-" + name, action="store_false", dest=name)
+    dest = name.replace('-', '_')
+    parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help)
+    parser.add_argument("--no-" + name, action="store_false", dest=dest)


 def get_wrapper_by_name(env, classname):
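With the explicit `dest`, a hyphenated flag such as `render-eval` now maps both `--render-eval` and `--no-render-eval` onto the same `render_eval` attribute instead of argparse deriving two different destinations. A small sketch of the resulting behaviour, reusing the pattern from the hunk above:

```python
import argparse

def boolean_flag(parser, name, default=False, help=None):
    # Same pattern as the diff above: one destination for both flags.
    dest = name.replace('-', '_')
    parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help)
    parser.add_argument("--no-" + name, action="store_false", dest=dest)

parser = argparse.ArgumentParser()
boolean_flag(parser, 'render-eval', default=False)
print(parser.parse_args(['--render-eval']).render_eval)     # True
print(parser.parse_args(['--no-render-eval']).render_eval)  # False
```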
@@ -1,6 +1,6 @@
 import os, subprocess, sys

-def mpi_fork(n):
+def mpi_fork(n, bind_to_core=False):
     """Re-launches the current script with workers
     Returns "parent" for original parent, "child" for MPI children
     """
@@ -13,7 +13,11 @@ def mpi_fork(n):
             OMP_NUM_THREADS="1",
             IN_MPI="1"
         )
-        subprocess.check_call(["mpirun", "-np", str(n), sys.executable] + sys.argv, env=env)
+        args = ["mpirun", "-np", str(n)]
+        if bind_to_core:
+            args += ["-bind-to", "core"]
+        args += [sys.executable] + sys.argv
+        subprocess.check_call(args, env=env)
         return "parent"
     else:
         return "child"
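The reworked `mpi_fork` only appends `-bind-to core` to the `mpirun` command when asked. A hedged usage sketch (the relaunch only works where `mpirun` is installed; the 4-worker count is illustrative):

```python
# Hypothetical launcher: relaunch this script under 4 MPI workers,
# optionally pinning each worker to a core.
from baselines.common.mpi_fork import mpi_fork

if mpi_fork(4, bind_to_core=True) == "parent":
    # The original process exits once the MPI children are running.
    raise SystemExit(0)
# ... per-worker training code runs here in each MPI child ...
```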
baselines/ddpg/__init__.py (new file, 0 lines)

baselines/ddpg/ddpg.py (new file, 372 lines)
@@ -0,0 +1,372 @@
from copy import copy
from functools import reduce

import numpy as np
import tensorflow as tf
import tensorflow.contrib as tc

from baselines import logger
from baselines.common.mpi_adam import MpiAdam
import baselines.common.tf_util as U
from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.ddpg.util import reduce_std, mpi_mean


def normalize(x, stats):
    if stats is None:
        return x
    return (x - stats.mean) / stats.std


def denormalize(x, stats):
    if stats is None:
        return x
    return x * stats.std + stats.mean


def get_target_updates(vars, target_vars, tau):
    logger.info('setting up target updates ...')
    soft_updates = []
    init_updates = []
    assert len(vars) == len(target_vars)
    for var, target_var in zip(vars, target_vars):
        logger.info('  {} <- {}'.format(target_var.name, var.name))
        init_updates.append(tf.assign(target_var, var))
        soft_updates.append(tf.assign(target_var, (1. - tau) * target_var + tau * var))
    assert len(init_updates) == len(vars)
    assert len(soft_updates) == len(vars)
    return tf.group(*init_updates), tf.group(*soft_updates)


def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev):
    assert len(actor.vars) == len(perturbed_actor.vars)
    assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)

    updates = []
    for var, perturbed_var in zip(actor.vars, perturbed_actor.vars):
        if var in actor.perturbable_vars:
            logger.info('  {} <- {} + noise'.format(perturbed_var.name, var.name))
            updates.append(tf.assign(perturbed_var, var + tf.random_normal(tf.shape(var), mean=0., stddev=param_noise_stddev)))
        else:
            logger.info('  {} <- {}'.format(perturbed_var.name, var.name))
            updates.append(tf.assign(perturbed_var, var))
    assert len(updates) == len(actor.vars)
    return tf.group(*updates)


class DDPG(object):
    def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None,
        gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
        batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf),
        adaptive_param_noise=True, adaptive_param_noise_policy_threshold=.1,
        critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
        # Inputs.
        self.obs0 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs0')
        self.obs1 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs1')
        self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1')
        self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
        self.actions = tf.placeholder(tf.float32, shape=(None,) + action_shape, name='actions')
        self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target')
        self.param_noise_stddev = tf.placeholder(tf.float32, shape=(), name='param_noise_stddev')

        # Parameters.
        self.gamma = gamma
        self.tau = tau
        self.memory = memory
        self.normalize_observations = normalize_observations
        self.normalize_returns = normalize_returns
        self.action_noise = action_noise
        self.param_noise = param_noise
        self.action_range = action_range
        self.return_range = return_range
        self.observation_range = observation_range
        self.critic = critic
        self.actor = actor
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.clip_norm = clip_norm
        self.enable_popart = enable_popart
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.stats_sample = None
        self.critic_l2_reg = critic_l2_reg

        # Observation normalization.
        if self.normalize_observations:
            with tf.variable_scope('obs_rms'):
                self.obs_rms = RunningMeanStd(shape=observation_shape)
        else:
            self.obs_rms = None
        normalized_obs0 = tf.clip_by_value(normalize(self.obs0, self.obs_rms),
            self.observation_range[0], self.observation_range[1])
        normalized_obs1 = tf.clip_by_value(normalize(self.obs1, self.obs_rms),
            self.observation_range[0], self.observation_range[1])

        # Return normalization.
        if self.normalize_returns:
            with tf.variable_scope('ret_rms'):
                self.ret_rms = RunningMeanStd()
        else:
            self.ret_rms = None

        # Create target networks.
        target_actor = copy(actor)
        target_actor.name = 'target_actor'
        self.target_actor = target_actor
        target_critic = copy(critic)
        target_critic.name = 'target_critic'
        self.target_critic = target_critic

        # Create networks and core TF parts that are shared across setup parts.
        self.actor_tf = actor(normalized_obs0)
        self.normalized_critic_tf = critic(normalized_obs0, self.actions)
        self.critic_tf = denormalize(tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
        self.normalized_critic_with_actor_tf = critic(normalized_obs0, self.actor_tf, reuse=True)
        self.critic_with_actor_tf = denormalize(tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
        Q_obs1 = denormalize(target_critic(normalized_obs1, target_actor(normalized_obs1)), self.ret_rms)
        self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1

        # Set up parts.
        if self.param_noise is not None:
            self.setup_param_noise(normalized_obs0)
        self.setup_actor_optimizer()
        self.setup_critic_optimizer()
        if self.normalize_returns and self.enable_popart:
            self.setup_popart()
        self.setup_stats()
        self.setup_target_network_updates()

    def setup_target_network_updates(self):
        actor_init_updates, actor_soft_updates = get_target_updates(self.actor.vars, self.target_actor.vars, self.tau)
        critic_init_updates, critic_soft_updates = get_target_updates(self.critic.vars, self.target_critic.vars, self.tau)
        self.target_init_updates = [actor_init_updates, critic_init_updates]
        self.target_soft_updates = [actor_soft_updates, critic_soft_updates]

    def setup_param_noise(self, normalized_obs0):
        assert self.param_noise is not None

        # Configure perturbed actor.
        param_noise_actor = copy(self.actor)
        param_noise_actor.name = 'param_noise_actor'
        self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
        logger.info('setting up param noise')
        self.perturb_policy_ops = get_perturbed_actor_updates(self.actor, param_noise_actor, self.param_noise_stddev)

        # Configure separate copy for stddev adoption.
        adaptive_param_noise_actor = copy(self.actor)
        adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
        adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
        self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
        self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))

    def setup_actor_optimizer(self):
        logger.info('setting up actor optimizer')
        self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)
        actor_shapes = [var.get_shape().as_list() for var in self.actor.trainable_vars]
        actor_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
        logger.info('  actor shapes: {}'.format(actor_shapes))
        logger.info('  actor params: {}'.format(actor_nb_params))
        self.actor_grads = U.flatgrad(self.actor_loss, self.actor.trainable_vars, clip_norm=self.clip_norm)
        self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
            beta1=0.9, beta2=0.999, epsilon=1e-08)

    def setup_critic_optimizer(self):
        logger.info('setting up critic optimizer')
        normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
        self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
        if self.critic_l2_reg > 0.:
            critic_reg_vars = [var for var in self.critic.trainable_vars if 'kernel' in var.name and 'output' not in var.name]
            for var in critic_reg_vars:
                logger.info('  regularizing: {}'.format(var.name))
            logger.info('  applying l2 regularization with {}'.format(self.critic_l2_reg))
            critic_reg = tc.layers.apply_regularization(
                tc.layers.l2_regularizer(self.critic_l2_reg),
                weights_list=critic_reg_vars
            )
            self.critic_loss += critic_reg
        critic_shapes = [var.get_shape().as_list() for var in self.critic.trainable_vars]
        critic_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
        logger.info('  critic shapes: {}'.format(critic_shapes))
        logger.info('  critic params: {}'.format(critic_nb_params))
        self.critic_grads = U.flatgrad(self.critic_loss, self.critic.trainable_vars, clip_norm=self.clip_norm)
        self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars,
            beta1=0.9, beta2=0.999, epsilon=1e-08)

    def setup_popart(self):
        # See https://arxiv.org/pdf/1602.07714.pdf for details.
        self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
        new_std = self.ret_rms.std
        self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
        new_mean = self.ret_rms.mean

        self.renormalize_Q_outputs_op = []
        for vs in [self.critic.output_vars, self.target_critic.output_vars]:
            assert len(vs) == 2
            M, b = vs
            assert 'kernel' in M.name
            assert 'bias' in b.name
            assert M.get_shape()[-1] == 1
            assert b.get_shape()[-1] == 1
            self.renormalize_Q_outputs_op += [M.assign(M * self.old_std / new_std)]
            self.renormalize_Q_outputs_op += [b.assign((b * self.old_std + self.old_mean - new_mean) / new_std)]

    def setup_stats(self):
        ops = []
        names = []

        if self.normalize_returns:
            ops += [self.ret_rms.mean, self.ret_rms.std]
            names += ['ret_rms_mean', 'ret_rms_std']

        if self.normalize_observations:
            ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)]
            names += ['obs_rms_mean', 'obs_rms_std']

        ops += [tf.reduce_mean(self.critic_tf)]
        names += ['reference_Q_mean']
        ops += [reduce_std(self.critic_tf)]
        names += ['reference_Q_std']

        ops += [tf.reduce_mean(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_mean']
        ops += [reduce_std(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_std']

        ops += [tf.reduce_mean(self.actor_tf)]
        names += ['reference_action_mean']
        ops += [reduce_std(self.actor_tf)]
        names += ['reference_action_std']

        if self.param_noise:
            ops += [tf.reduce_mean(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_mean']
            ops += [reduce_std(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_std']

        self.stats_ops = ops
        self.stats_names = names

    def pi(self, obs, apply_noise=True, compute_Q=True):
        if self.param_noise is not None and apply_noise:
            actor_tf = self.perturbed_actor_tf
        else:
            actor_tf = self.actor_tf
        feed_dict = {self.obs0: [obs]}
        if compute_Q:
            action, q = self.sess.run([actor_tf, self.critic_with_actor_tf], feed_dict=feed_dict)
        else:
            action = self.sess.run(actor_tf, feed_dict=feed_dict)
            q = None
        action = action.flatten()
        if self.action_noise is not None and apply_noise:
            noise = self.action_noise()
            assert noise.shape == action.shape
            action += noise
        action = np.clip(action, self.action_range[0], self.action_range[1])
        return action, q

    def store_transition(self, obs0, action, reward, obs1, terminal1):
        reward *= self.reward_scale
        self.memory.append(obs0, action, reward, obs1, terminal1)
        if self.normalize_observations:
            self.obs_rms.update(np.array([obs0]))

    def train(self):
        # Get a batch.
        batch = self.memory.sample(batch_size=self.batch_size)

        if self.normalize_returns and self.enable_popart:
            old_mean, old_std, target_Q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_Q], feed_dict={
                self.obs1: batch['obs1'],
                self.rewards: batch['rewards'],
                self.terminals1: batch['terminals1'].astype('float32'),
            })
            self.ret_rms.update(target_Q.flatten())
            self.sess.run(self.renormalize_Q_outputs_op, feed_dict={
                self.old_std : np.array([old_std]),
                self.old_mean : np.array([old_mean]),
            })

            # Run sanity check. Disabled by default since it slows down things considerably.
            # print('running sanity check')
            # target_Q_new, new_mean, new_std = self.sess.run([self.target_Q, self.ret_rms.mean, self.ret_rms.std], feed_dict={
            #     self.obs1: batch['obs1'],
            #     self.rewards: batch['rewards'],
            #     self.terminals1: batch['terminals1'].astype('float32'),
            # })
            # print(target_Q_new, target_Q, new_mean, new_std)
            # assert (np.abs(target_Q - target_Q_new) < 1e-3).all()
        else:
            target_Q = self.sess.run(self.target_Q, feed_dict={
                self.obs1: batch['obs1'],
                self.rewards: batch['rewards'],
                self.terminals1: batch['terminals1'].astype('float32'),
            })

        # Get all gradients and perform a synced update.
        ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
        actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, feed_dict={
            self.obs0: batch['obs0'],
            self.actions: batch['actions'],
            self.critic_target: target_Q,
        })
        self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
        self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)

        return critic_loss, actor_loss

    def initialize(self, sess):
        self.sess = sess
        self.sess.run(tf.global_variables_initializer())
        self.actor_optimizer.sync()
        self.critic_optimizer.sync()
        self.sess.run(self.target_init_updates)

    def update_target_net(self):
        self.sess.run(self.target_soft_updates)

    def get_stats(self):
        if self.stats_sample is None:
            # Get a sample and keep that fixed for all further computations.
            # This allows us to estimate the change in value for the same set of inputs.
            self.stats_sample = self.memory.sample(batch_size=self.batch_size)
        values = self.sess.run(self.stats_ops, feed_dict={
            self.obs0: self.stats_sample['obs0'],
            self.actions: self.stats_sample['actions'],
        })

        names = self.stats_names[:]
        assert len(names) == len(values)
        stats = dict(zip(names, values))

        if self.param_noise is not None:
            stats = {**stats, **self.param_noise.get_stats()}

        return stats

    def adapt_param_noise(self):
        if self.param_noise is None:
            return 0.

        # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
        batch = self.memory.sample(batch_size=self.batch_size)
        self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
            self.param_noise_stddev: self.param_noise.current_stddev,
        })
        distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
            self.obs0: batch['obs0'],
            self.param_noise_stddev: self.param_noise.current_stddev,
        })

        mean_distance = mpi_mean(distance)
        self.param_noise.adapt(mean_distance)
        return mean_distance

    def reset(self):
        # Reset internal state after an episode is complete.
        if self.action_noise is not None:
            self.action_noise.reset()
        if self.param_noise is not None:
            self.sess.run(self.perturb_policy_ops, feed_dict={
                self.param_noise_stddev: self.param_noise.current_stddev,
            })
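A condensed sketch of how the pieces of `DDPG` above fit together (the real loop lives in baselines/ddpg/training.py further down; the shapes, the dummy observation, and the use of a plain `tf.Session` here are illustrative assumptions, not part of the diff):

```python
import numpy as np
import tensorflow as tf
from baselines.ddpg.ddpg import DDPG
from baselines.ddpg.models import Actor, Critic
from baselines.ddpg.memory import Memory
from baselines.ddpg.noise import AdaptiveParamNoiseSpec

obs_shape, action_shape = (17,), (6,)   # illustrative shapes
memory = Memory(limit=int(1e6), action_shape=action_shape, observation_shape=obs_shape)
agent = DDPG(Actor(action_shape[0]), Critic(), memory, obs_shape, action_shape,
             param_noise=AdaptiveParamNoiseSpec(initial_stddev=0.2, desired_action_stddev=0.2))
with tf.Session() as sess:
    agent.initialize(sess)
    agent.reset()                                  # samples a fresh parameter perturbation
    obs = np.zeros(obs_shape, dtype=np.float32)    # stands in for env.reset()
    action, q = agent.pi(obs, apply_noise=True, compute_Q=True)
    agent.store_transition(obs, action, 0.0, obs, False)
```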
baselines/ddpg/main.py (new file, 161 lines)
@@ -0,0 +1,161 @@
import argparse
import time
import os
from tempfile import mkdtemp
import sys
import subprocess
import threading
import json

from baselines.common.mpi_fork import mpi_fork
from baselines import logger
from baselines.logger import Logger
from baselines.common.misc_util import (
    set_global_seeds,
    boolean_flag,
    SimpleMonitor
)
import baselines.ddpg.training as training
from baselines.ddpg.models import Actor, Critic
from baselines.ddpg.memory import Memory
from baselines.ddpg.noise import *

import gym
import tensorflow as tf
from mpi4py import MPI


def run(env_id, seed, noise_type, num_cpu, layer_norm, logdir, gym_monitor, evaluation, bind_to_core, **kwargs):
    kwargs['logdir'] = logdir
    whoami = mpi_fork(num_cpu, bind_to_core=bind_to_core)
    if whoami == 'parent':
        sys.exit(0)

    # Configure things.
    rank = MPI.COMM_WORLD.Get_rank()
    if rank != 0:
        # Write to temp directory for all non-master workers.
        actual_dir = None
        Logger.CURRENT.close()
        Logger.CURRENT = Logger(dir=mkdtemp(), output_formats=[])
        logger.set_level(logger.DISABLED)

    # Create envs.
    if rank == 0:
        env = gym.make(env_id)
        if gym_monitor and logdir:
            env = gym.wrappers.Monitor(env, os.path.join(logdir, 'gym_train'), force=True)
        env = SimpleMonitor(env)

        if evaluation:
            eval_env = gym.make(env_id)
            if gym_monitor and logdir:
                eval_env = gym.wrappers.Monitor(eval_env, os.path.join(logdir, 'gym_eval'), force=True)
            eval_env = SimpleMonitor(eval_env)
        else:
            eval_env = None
    else:
        env = gym.make(env_id)
        if evaluation:
            eval_env = gym.make(env_id)
        else:
            eval_env = None

    # Parse noise_type
    action_noise = None
    param_noise = None
    nb_actions = env.action_space.shape[-1]
    for current_noise_type in noise_type.split(','):
        current_noise_type = current_noise_type.strip()
        if current_noise_type == 'none':
            pass
        elif 'adaptive-param' in current_noise_type:
            _, stddev = current_noise_type.split('_')
            param_noise = AdaptiveParamNoiseSpec(initial_stddev=float(stddev), desired_action_stddev=float(stddev))
        elif 'normal' in current_noise_type:
            _, stddev = current_noise_type.split('_')
            action_noise = NormalActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions))
        elif 'ou' in current_noise_type:
            _, stddev = current_noise_type.split('_')
            action_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions))
        else:
            raise RuntimeError('unknown noise type "{}"'.format(current_noise_type))

    # Configure components.
    memory = Memory(limit=int(1e6), action_shape=env.action_space.shape, observation_shape=env.observation_space.shape)
    critic = Critic(layer_norm=layer_norm)
    actor = Actor(nb_actions, layer_norm=layer_norm)

    # Seed everything to make things reproducible.
    seed = seed + 1000000 * rank
    logger.info('rank {}: seed={}, logdir={}'.format(rank, seed, logger.get_dir()))
    tf.reset_default_graph()
    set_global_seeds(seed)
    env.seed(seed)
    if eval_env is not None:
        eval_env.seed(seed)

    # Disable logging for rank != 0 to avoid noise.
    if rank == 0:
        start_time = time.time()
    training.train(env=env, eval_env=eval_env, param_noise=param_noise,
        action_noise=action_noise, actor=actor, critic=critic, memory=memory, **kwargs)
    env.close()
    if eval_env is not None:
        eval_env.close()
    Logger.CURRENT.close()
    if rank == 0:
        logger.info('total runtime: {}s'.format(time.time() - start_time))


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--env-id', type=str, default='HalfCheetah-v1')
    boolean_flag(parser, 'render-eval', default=False)
    boolean_flag(parser, 'layer-norm', default=True)
    boolean_flag(parser, 'render', default=False)
    parser.add_argument('--num-cpu', type=int, default=1)
    boolean_flag(parser, 'normalize-returns', default=False)
    boolean_flag(parser, 'normalize-observations', default=True)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--critic-l2-reg', type=float, default=1e-2)
    parser.add_argument('--batch-size', type=int, default=64)  # per MPI worker
    parser.add_argument('--actor-lr', type=float, default=1e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    boolean_flag(parser, 'popart', default=False)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--reward-scale', type=float, default=1.)
    parser.add_argument('--clip-norm', type=float, default=None)
    parser.add_argument('--nb-epochs', type=int, default=500)  # with default settings, perform 1M steps total
    parser.add_argument('--nb-epoch-cycles', type=int, default=20)
    parser.add_argument('--nb-train-steps', type=int, default=50)  # per epoch cycle and MPI worker
    parser.add_argument('--nb-eval-steps', type=int, default=100)  # per epoch cycle and MPI worker
    parser.add_argument('--nb-rollout-steps', type=int, default=100)  # per epoch cycle and MPI worker
    parser.add_argument('--noise-type', type=str, default='adaptive-param_0.2')  # choices are adaptive-param_xx, ou_xx, normal_xx, none
    parser.add_argument('--logdir', type=str, default=None)
    boolean_flag(parser, 'gym-monitor', default=False)
    boolean_flag(parser, 'evaluation', default=True)
    boolean_flag(parser, 'bind-to-core', default=False)

    return vars(parser.parse_args())


if __name__ == '__main__':
    args = parse_args()

    # Figure out what logdir to use.
    if args['logdir'] is None:
        args['logdir'] = os.getenv('OPENAI_LOGDIR')

    # Print and save arguments.
    logger.info('Arguments:')
    for key in sorted(args.keys()):
        logger.info('{}: {}'.format(key, args[key]))
    logger.info('')
    if args['logdir']:
        with open(os.path.join(args['logdir'], 'args.json'), 'w') as f:
            json.dump(args, f)

    # Run actual script.
    run(**args)
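The `--noise-type` argument encodes the exploration scheme as comma-separated `<name>_<stddev>` tokens, which the loop in `run` above parses into `param_noise` and `action_noise` objects. A small sketch of the same tokenization in isolation (this standalone helper is hypothetical, not part of the diff):

```python
def parse_noise_spec(noise_type):
    # e.g. 'adaptive-param_0.2' or 'ou_0.3,adaptive-param_0.2'
    specs = []
    for token in noise_type.split(','):
        token = token.strip()
        if token == 'none':
            continue
        name, stddev = token.split('_')
        specs.append((name, float(stddev)))
    return specs

print(parse_noise_spec('adaptive-param_0.2'))  # [('adaptive-param', 0.2)]
print(parse_noise_spec('ou_0.3,normal_0.1'))   # [('ou', 0.3), ('normal', 0.1)]
```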
baselines/ddpg/memory.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import numpy as np


class RingBuffer(object):
    def __init__(self, maxlen, shape, dtype='float32'):
        self.maxlen = maxlen
        self.start = 0
        self.length = 0
        self.data = np.zeros((maxlen,) + shape).astype(dtype)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if idx < 0 or idx >= self.length:
            raise KeyError()
        return self.data[(self.start + idx) % self.maxlen]

    def get_batch(self, idxs):
        return self.data[(self.start + idxs) % self.maxlen]

    def append(self, v):
        if self.length < self.maxlen:
            # We have space, simply increase the length.
            self.length += 1
        elif self.length == self.maxlen:
            # No space, "remove" the first item.
            self.start = (self.start + 1) % self.maxlen
        else:
            # This should never happen.
            raise RuntimeError()
        self.data[(self.start + self.length - 1) % self.maxlen] = v


def array_min2d(x):
    x = np.array(x)
    if x.ndim >= 2:
        return x
    return x.reshape(-1, 1)


class Memory(object):
    def __init__(self, limit, action_shape, observation_shape):
        self.limit = limit

        self.observations0 = RingBuffer(limit, shape=observation_shape)
        self.actions = RingBuffer(limit, shape=action_shape)
        self.rewards = RingBuffer(limit, shape=(1,))
        self.terminals1 = RingBuffer(limit, shape=(1,))
        self.observations1 = RingBuffer(limit, shape=observation_shape)

    def sample(self, batch_size):
        # Draw such that we always have a proceeding element.
        batch_idxs = np.random.random_integers(self.nb_entries - 2, size=batch_size)

        obs0_batch = self.observations0.get_batch(batch_idxs)
        obs1_batch = self.observations1.get_batch(batch_idxs)
        action_batch = self.actions.get_batch(batch_idxs)
        reward_batch = self.rewards.get_batch(batch_idxs)
        terminal1_batch = self.terminals1.get_batch(batch_idxs)

        result = {
            'obs0': array_min2d(obs0_batch),
            'obs1': array_min2d(obs1_batch),
            'rewards': array_min2d(reward_batch),
            'actions': array_min2d(action_batch),
            'terminals1': array_min2d(terminal1_batch),
        }
        return result

    def append(self, obs0, action, reward, obs1, terminal1, training=True):
        if not training:
            return

        self.observations0.append(obs0)
        self.actions.append(action)
        self.rewards.append(reward)
        self.observations1.append(obs1)
        self.terminals1.append(terminal1)

    @property
    def nb_entries(self):
        return len(self.observations0)
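A quick sketch of the replay-buffer round trip (note that `sample` draws indices only up to `nb_entries - 2`, so a successor observation always exists); the buffer size and shapes below are illustrative:

```python
import numpy as np
from baselines.ddpg.memory import Memory  # assuming this branch is installed

memory = Memory(limit=100, action_shape=(2,), observation_shape=(3,))
for step in range(10):
    memory.append(obs0=np.random.randn(3),
                  action=np.random.uniform(-1., 1., size=2),
                  reward=float(step),
                  obs1=np.random.randn(3),
                  terminal1=False)

batch = memory.sample(batch_size=4)
# Rewards and terminals are stored with shape (1,), so every batch entry is at least 2-D.
print(batch['obs0'].shape, batch['actions'].shape, batch['rewards'].shape)  # (4, 3) (4, 2) (4, 1)
```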
baselines/ddpg/models.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import tensorflow as tf
import tensorflow.contrib as tc


class Model(object):
    def __init__(self, name):
        self.name = name

    @property
    def vars(self):
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)

    @property
    def trainable_vars(self):
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

    @property
    def perturbable_vars(self):
        return [var for var in self.trainable_vars if 'LayerNorm' not in var.name]


class Actor(Model):
    def __init__(self, nb_actions, name='actor', layer_norm=True):
        super(Actor, self).__init__(name=name)
        self.nb_actions = nb_actions
        self.layer_norm = layer_norm

    def __call__(self, obs, reuse=False):
        with tf.variable_scope(self.name) as scope:
            if reuse:
                scope.reuse_variables()

            x = obs
            x = tf.layers.dense(x, 64)
            if self.layer_norm:
                x = tc.layers.layer_norm(x, center=True, scale=True)
            x = tf.nn.relu(x)

            x = tf.layers.dense(x, 64)
            if self.layer_norm:
                x = tc.layers.layer_norm(x, center=True, scale=True)
            x = tf.nn.relu(x)

            x = tf.layers.dense(x, self.nb_actions, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
            x = tf.nn.tanh(x)
        return x


class Critic(Model):
    def __init__(self, name='critic', layer_norm=True):
        super(Critic, self).__init__(name=name)
        self.layer_norm = layer_norm

    def __call__(self, obs, action, reuse=False):
        with tf.variable_scope(self.name) as scope:
            if reuse:
                scope.reuse_variables()

            x = obs
            x = tf.layers.dense(x, 64)
            if self.layer_norm:
                x = tc.layers.layer_norm(x, center=True, scale=True)
            x = tf.nn.relu(x)

            x = tf.concat([x, action], axis=-1)
            x = tf.layers.dense(x, 64)
            if self.layer_norm:
                x = tc.layers.layer_norm(x, center=True, scale=True)
            x = tf.nn.relu(x)

            x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
        return x

    @property
    def output_vars(self):
        output_vars = [var for var in self.trainable_vars if 'output' in var.name]
        return output_vars
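The `perturbable_vars` property is what keeps parameter noise away from the LayerNorm scale and offset variables, which exist only for normalization. A brief illustrative sketch (assumes a TF1-style graph with tf.contrib available, as the models above do):

```python
import tensorflow as tf
from baselines.ddpg.models import Actor  # assuming this branch is installed

actor = Actor(nb_actions=2, layer_norm=True)
obs_ph = tf.placeholder(tf.float32, shape=(None, 3))
pi = actor(obs_ph)
# LayerNorm variables are trainable but excluded from perturbation.
print(len(actor.trainable_vars), len(actor.perturbable_vars))
```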
baselines/ddpg/noise.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import numpy as np


class AdaptiveParamNoiseSpec(object):
    def __init__(self, initial_stddev=0.1, desired_action_stddev=0.1, adoption_coefficient=1.01):
        self.initial_stddev = initial_stddev
        self.desired_action_stddev = desired_action_stddev
        self.adoption_coefficient = adoption_coefficient

        self.current_stddev = initial_stddev

    def adapt(self, distance):
        if distance > self.desired_action_stddev:
            # Decrease stddev.
            self.current_stddev /= self.adoption_coefficient
        else:
            # Increase stddev.
            self.current_stddev *= self.adoption_coefficient

    def get_stats(self):
        stats = {
            'param_noise_stddev': self.current_stddev,
        }
        return stats

    def __repr__(self):
        fmt = 'AdaptiveParamNoiseSpec(initial_stddev={}, desired_action_stddev={}, adoption_coefficient={})'
        return fmt.format(self.initial_stddev, self.desired_action_stddev, self.adoption_coefficient)


class ActionNoise(object):
    def reset(self):
        pass


class NormalActionNoise(ActionNoise):
    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

    def __call__(self):
        return np.random.normal(self.mu, self.sigma)

    def __repr__(self):
        return 'NormalActionNoise(mu={}, sigma={})'.format(self.mu, self.sigma)


# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
class OrnsteinUhlenbeckActionNoise(ActionNoise):
    def __init__(self, mu, sigma, theta=.15, dt=1e-2, x0=None):
        self.theta = theta
        self.mu = mu
        self.sigma = sigma
        self.dt = dt
        self.x0 = x0
        self.reset()

    def __call__(self):
        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
        self.x_prev = x
        return x

    def reset(self):
        self.x_prev = self.x0 if self.x0 is not None else np.zeros_like(self.mu)

    def __repr__(self):
        return 'OrnsteinUhlenbeckActionNoise(mu={}, sigma={})'.format(self.mu, self.sigma)
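The adaptive spec is a simple multiplicative controller: when the measured action-space distance exceeds the target, the perturbation stddev shrinks by `adoption_coefficient`, otherwise it grows by the same factor. A standalone sketch of that behaviour:

```python
from baselines.ddpg.noise import AdaptiveParamNoiseSpec  # assuming this branch is installed

spec = AdaptiveParamNoiseSpec(initial_stddev=0.1, desired_action_stddev=0.2, adoption_coefficient=1.01)
spec.adapt(distance=0.5)    # too much action-space change -> stddev shrinks
print(spec.current_stddev)  # ~0.0990
spec.adapt(distance=0.05)   # too little change -> stddev grows back
print(spec.current_stddev)  # ~0.1000
```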
baselines/ddpg/training.py (new file, 189 lines)
@@ -0,0 +1,189 @@
import os
import time
from collections import deque
import pickle

from baselines.ddpg.ddpg import DDPG
from baselines.ddpg.util import mpi_mean, mpi_std, mpi_max, mpi_sum
import baselines.common.tf_util as U

from baselines import logger
import numpy as np
import tensorflow as tf
from mpi4py import MPI


def train(env, nb_epochs, nb_epoch_cycles, render_eval, reward_scale, render, param_noise, actor, critic,
    normalize_returns, normalize_observations, critic_l2_reg, actor_lr, critic_lr, action_noise, logdir,
    popart, gamma, clip_norm, nb_train_steps, nb_rollout_steps, nb_eval_steps, batch_size, memory,
    tau=0.01, eval_env=None, param_noise_adaption_interval=50):
    rank = MPI.COMM_WORLD.Get_rank()

    assert (np.abs(env.action_space.low) == env.action_space.high).all()  # we assume symmetric actions.
    max_action = env.action_space.high
    logger.info('scaling actions by {} before executing in env'.format(max_action))
    agent = DDPG(actor, critic, memory, env.observation_space.shape, env.action_space.shape,
        gamma=gamma, tau=tau, normalize_returns=normalize_returns, normalize_observations=normalize_observations,
        batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
        actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm,
        reward_scale=reward_scale)
    logger.info('Using agent with the following configuration:')
    logger.info(str(agent.__dict__.items()))

    # Set up logging stuff only for a single worker.
    if rank == 0:
        saver = tf.train.Saver()
    else:
        saver = None

    step = 0
    episode = 0
    eval_episode_rewards_history = deque(maxlen=100)
    episode_rewards_history = deque(maxlen=100)
    with U.single_threaded_session() as sess:
        # Prepare everything.
        agent.initialize(sess)
        sess.graph.finalize()

        agent.reset()
        obs = env.reset()
        if eval_env is not None:
            eval_obs = eval_env.reset()
        done = False
        episode_reward = 0.
        episode_step = 0
        episodes = 0
        t = 0

        epoch = 0
        start_time = time.time()

        epoch_episode_rewards = []
        epoch_episode_steps = []
        epoch_episode_eval_rewards = []
        epoch_episode_eval_steps = []
        epoch_start_time = time.time()
        epoch_actions = []
        epoch_qs = []
        epoch_episodes = 0
        for epoch in range(nb_epochs):
            for cycle in range(nb_epoch_cycles):
                # Perform rollouts.
                for t_rollout in range(nb_rollout_steps):
                    # Predict next action.
                    action, q = agent.pi(obs, apply_noise=True, compute_Q=True)
                    assert action.shape == env.action_space.shape

                    # Execute next action.
                    if rank == 0 and render:
                        env.render()
                    assert max_action.shape == action.shape
                    new_obs, r, done, info = env.step(max_action * action)  # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
                    t += 1
                    if rank == 0 and render:
                        env.render()
                    episode_reward += r
                    episode_step += 1

                    # Book-keeping.
                    epoch_actions.append(action)
                    epoch_qs.append(q)
                    agent.store_transition(obs, action, r, new_obs, done)
                    obs = new_obs

                    if done:
                        # Episode done.
                        epoch_episode_rewards.append(episode_reward)
                        episode_rewards_history.append(episode_reward)
                        epoch_episode_steps.append(episode_step)
                        episode_reward = 0.
                        episode_step = 0
                        epoch_episodes += 1
                        episodes += 1

                        agent.reset()
                        obs = env.reset()

                # Train.
                epoch_actor_losses = []
                epoch_critic_losses = []
                epoch_adaptive_distances = []
                for t_train in range(nb_train_steps):
                    # Adapt param noise, if necessary.
                    if memory.nb_entries >= batch_size and t % param_noise_adaption_interval == 0:
                        distance = agent.adapt_param_noise()
                        epoch_adaptive_distances.append(distance)

                    cl, al = agent.train()
                    epoch_critic_losses.append(cl)
                    epoch_actor_losses.append(al)
                    agent.update_target_net()

                # Evaluate.
                eval_episode_rewards = []
                eval_qs = []
                if eval_env is not None:
                    eval_episode_reward = 0.
                    for t_rollout in range(nb_eval_steps):
                        eval_action, eval_q = agent.pi(eval_obs, apply_noise=False, compute_Q=True)
                        eval_obs, eval_r, eval_done, eval_info = eval_env.step(max_action * eval_action)  # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
                        if render_eval:
                            eval_env.render()
                        eval_episode_reward += eval_r

                        eval_qs.append(eval_q)
                        if eval_done:
                            eval_obs = eval_env.reset()
                            eval_episode_rewards.append(eval_episode_reward)
                            eval_episode_rewards_history.append(eval_episode_reward)
                            eval_episode_reward = 0.

            # Log stats.
            epoch_train_duration = time.time() - epoch_start_time
            duration = time.time() - start_time
            stats = agent.get_stats()
            combined_stats = {}
            for key in sorted(stats.keys()):
                combined_stats[key] = mpi_mean(stats[key])

            # Rollout statistics.
            combined_stats['rollout/return'] = mpi_mean(epoch_episode_rewards)
            combined_stats['rollout/return_history'] = mpi_mean(np.mean(episode_rewards_history))
            combined_stats['rollout/episode_steps'] = mpi_mean(epoch_episode_steps)
            combined_stats['rollout/episodes'] = mpi_sum(epoch_episodes)
            combined_stats['rollout/actions_mean'] = mpi_mean(epoch_actions)
            combined_stats['rollout/actions_std'] = mpi_std(epoch_actions)
            combined_stats['rollout/Q_mean'] = mpi_mean(epoch_qs)

            # Train statistics.
            combined_stats['train/loss_actor'] = mpi_mean(epoch_actor_losses)
            combined_stats['train/loss_critic'] = mpi_mean(epoch_critic_losses)
            combined_stats['train/param_noise_distance'] = mpi_mean(epoch_adaptive_distances)

            # Evaluation statistics.
            if eval_env is not None:
                combined_stats['eval/return'] = mpi_mean(eval_episode_rewards)
                combined_stats['eval/return_history'] = mpi_mean(np.mean(eval_episode_rewards_history))
                combined_stats['eval/Q'] = mpi_mean(eval_qs)
                combined_stats['eval/episodes'] = mpi_mean(len(eval_episode_rewards))

            # Total statistics.
            combined_stats['total/duration'] = mpi_mean(duration)
            combined_stats['total/steps_per_second'] = mpi_mean(float(t) / float(duration))
            combined_stats['total/episodes'] = mpi_mean(episodes)
            combined_stats['total/epochs'] = epoch + 1
            combined_stats['total/steps'] = t

            for key in sorted(combined_stats.keys()):
                logger.record_tabular(key, combined_stats[key])
            logger.dump_tabular()
            logger.info('')

            if rank == 0 and logdir:
                if hasattr(env, 'get_state'):
                    with open(os.path.join(logdir, 'env_state.pkl'), 'wb') as f:
                        pickle.dump(env.get_state(), f)
                if eval_env and hasattr(eval_env, 'get_state'):
                    with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f:
                        pickle.dump(eval_env.get_state(), f)
baselines/ddpg/util.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import time

import gym
import numpy as np
import tensorflow as tf
from mpi4py import MPI
from baselines.common.mpi_moments import mpi_moments


def reduce_var(x, axis=None, keepdims=False):
    m = tf.reduce_mean(x, axis=axis, keep_dims=True)
    devs_squared = tf.square(x - m)
    return tf.reduce_mean(devs_squared, axis=axis, keep_dims=keepdims)


def reduce_std(x, axis=None, keepdims=False):
    return tf.sqrt(reduce_var(x, axis=axis, keepdims=keepdims))


def mpi_mean(value):
    if value == []:
        value = [0.]
    if not isinstance(value, list):
        value = [value]
    return mpi_moments(np.array(value))[0][0]


def mpi_std(value):
    if value == []:
        value = [0.]
    if not isinstance(value, list):
        value = [value]
    return mpi_moments(np.array(value))[1][0]


def mpi_max(value):
    global_max = np.zeros(1, dtype='float64')
    local_max = np.max(value).astype('float64')
    MPI.COMM_WORLD.Reduce(local_max, global_max, op=MPI.MAX)
    return global_max[0]


def mpi_sum(value):
    global_sum = np.zeros(1, dtype='float64')
    local_sum = np.sum(np.array(value)).astype('float64')
    MPI.COMM_WORLD.Reduce(local_sum, global_sum, op=MPI.SUM)
    return global_sum[0]
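The `mpi_*` helpers reduce scalar statistics across workers; run in a single process (no `mpirun`), the reduction is effectively over one rank and just returns the local value. A hedged sketch under that single-process assumption:

```python
from baselines.ddpg.util import mpi_mean, mpi_std, mpi_sum  # assuming this branch is installed

rewards = [1.0, 2.0, 3.0]
print(mpi_mean(rewards))      # 2.0 on a single rank
print(mpi_std(rewards))       # ~0.816 (population std) on a single rank
print(mpi_sum(len(rewards)))  # 3.0 on a single rank
```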
@@ -22,6 +22,32 @@ The functions in this file can are used to create the following functions:
        every element of the batch.
+
+
+======= act (in case of parameter noise) ========
+
+    Function to chose an action given an observation
+
+    Parameters
+    ----------
+    observation: object
+        Observation that can be feed into the output of make_obs_ph
+    stochastic: bool
+        if set to False all the actions are always deterministic (default False)
+    update_eps_ph: float
+        update epsilon a new value, if negative not update happens
+        (default: no update)
+    reset_ph: bool
+        reset the perturbed policy by sampling a new perturbation
+    update_param_noise_threshold_ph: float
+        the desired threshold for the difference between non-perturbed and perturbed policy
+    update_param_noise_scale_ph: bool
+        whether or not to update the scale of the noise for the next time it is re-perturbed
+
+    Returns
+    -------
+    Tensor of dtype tf.int64 and shape (BATCH_SIZE,) with an action to be performed for
+    every element of the batch.
+
+
 ======= train =======

    Function that takes a transition (s,a,r,s') and optimizes Bellman equation's error:
@@ -71,6 +97,21 @@ import tensorflow as tf
 import baselines.common.tf_util as U


+def default_param_noise_filter(var):
+    if var not in tf.trainable_variables():
+        # We never perturb non-trainable vars.
+        return False
+    if "fully_connected" in var.name:
+        # We perturb fully-connected layers.
+        return True
+
+    # The remaining layers are likely conv or layer norm layers, which we do not wish to
+    # perturb (in the former case because they only extract features, in the latter case because
+    # we use them for normalization purposes). If you change your network, you will likely want
+    # to re-consider which layers to perturb and which to keep untouched.
+    return False
+
+
 def build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None):
     """Creates the act function:

@@ -118,7 +159,6 @@ def build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None):

         output_actions = tf.cond(stochastic_ph, lambda: stochastic_actions, lambda: deterministic_actions)
         update_eps_expr = eps.assign(tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
-
         act = U.function(inputs=[observations_ph, stochastic_ph, update_eps_ph],
                          outputs=output_actions,
                          givens={update_eps_ph: -1.0, stochastic_ph: True},
@@ -126,7 +166,121 @@ def build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None):
|
|||||||
return act
|
return act
|
||||||
|
|
||||||
|
|
||||||
def build_train(make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=None, gamma=1.0, double_q=True, scope="deepq", reuse=None):
|
def build_act_with_param_noise(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None, param_noise_filter_func=None):
|
||||||
|
"""Creates the act function with support for parameter space noise exploration (https://arxiv.org/abs/1706.01905):
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
make_obs_ph: str -> tf.placeholder or TfInput
|
||||||
|
a function that take a name and creates a placeholder of input with that name
|
||||||
|
q_func: (tf.Variable, int, str, bool) -> tf.Variable
|
||||||
|
the model that takes the following inputs:
|
||||||
|
observation_in: object
|
||||||
|
the output of observation placeholder
|
||||||
|
num_actions: int
|
||||||
|
number of actions
|
||||||
|
scope: str
|
||||||
|
reuse: bool
|
||||||
|
should be passed to outer variable scope
|
||||||
|
and returns a tensor of shape (batch_size, num_actions) with values of every action.
|
||||||
|
num_actions: int
|
||||||
|
number of actions.
|
||||||
|
scope: str or VariableScope
|
||||||
|
optional scope for variable_scope.
|
||||||
|
reuse: bool or None
|
||||||
|
whether or not the variables should be reused. To be able to reuse the scope must be given.
|
||||||
|
param_noise_filter_func: tf.Variable -> bool
|
||||||
|
function that decides whether or not a variable should be perturbed. Only applicable
|
||||||
|
if param_noise is True. If set to None, default_param_noise_filter is used by default.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
act: (tf.Variable, bool, float, bool, float, bool) -> tf.Variable
|
||||||
|
function to select and action given observation.
|
||||||
|
` See the top of the file for details.
|
||||||
|
"""
|
||||||
|
+    if param_noise_filter_func is None:
+        param_noise_filter_func = default_param_noise_filter
+
+    with tf.variable_scope(scope, reuse=reuse):
+        observations_ph = U.ensure_tf_input(make_obs_ph("observation"))
+        stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
+        update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")
+        update_param_noise_threshold_ph = tf.placeholder(tf.float32, (), name="update_param_noise_threshold")
+        update_param_noise_scale_ph = tf.placeholder(tf.bool, (), name="update_param_noise_scale")
+        reset_ph = tf.placeholder(tf.bool, (), name="reset")
+
+        eps = tf.get_variable("eps", (), initializer=tf.constant_initializer(0))
+        param_noise_scale = tf.get_variable("param_noise_scale", (), initializer=tf.constant_initializer(0.01), trainable=False)
+        param_noise_threshold = tf.get_variable("param_noise_threshold", (), initializer=tf.constant_initializer(0.05), trainable=False)
+
+        # Unmodified Q.
+        q_values = q_func(observations_ph.get(), num_actions, scope="q_func")
+
+        # Perturbable Q used for the actual rollout.
+        q_values_perturbed = q_func(observations_ph.get(), num_actions, scope="perturbed_q_func")
+        # We have to wrap this code into a function due to the way tf.cond() works. See
+        # https://stackoverflow.com/questions/37063952/confused-by-the-behavior-of-tf-cond for
+        # a more detailed discussion.
+        def perturb_vars(original_scope, perturbed_scope):
+            all_vars = U.scope_vars(U.absolute_scope_name("q_func"))
+            all_perturbed_vars = U.scope_vars(U.absolute_scope_name("perturbed_q_func"))
+            assert len(all_vars) == len(all_perturbed_vars)
+            perturb_ops = []
+            for var, perturbed_var in zip(all_vars, all_perturbed_vars):
+                if param_noise_filter_func(perturbed_var):
+                    # Perturb this variable.
+                    op = tf.assign(perturbed_var, var + tf.random_normal(shape=tf.shape(var), mean=0., stddev=param_noise_scale))
+                else:
+                    # Do not perturb, just assign.
+                    op = tf.assign(perturbed_var, var)
+                perturb_ops.append(op)
+            assert len(perturb_ops) == len(all_vars)
+            return tf.group(*perturb_ops)
+
+        # Set up functionality to re-compute `param_noise_scale`. This perturbs yet another copy
+        # of the network and measures the effect of that perturbation in action space. If the perturbation
+        # is too big, reduce scale of perturbation, otherwise increase.
+        q_values_adaptive = q_func(observations_ph.get(), num_actions, scope="adaptive_q_func")
+        perturb_for_adaption = perturb_vars(original_scope="q_func", perturbed_scope="adaptive_q_func")
+        kl = tf.reduce_sum(tf.nn.softmax(q_values) * (tf.log(tf.nn.softmax(q_values)) - tf.log(tf.nn.softmax(q_values_adaptive))), axis=-1)
+        mean_kl = tf.reduce_mean(kl)
+        def update_scale():
+            with tf.control_dependencies([perturb_for_adaption]):
+                update_scale_expr = tf.cond(mean_kl < param_noise_threshold,
+                    lambda: param_noise_scale.assign(param_noise_scale * 1.01),
+                    lambda: param_noise_scale.assign(param_noise_scale / 1.01),
+                )
+            return update_scale_expr
+
+        # Functionality to update the threshold for parameter space noise.
+        update_param_noise_threshold_expr = param_noise_threshold.assign(tf.cond(update_param_noise_threshold_ph >= 0,
+            lambda: update_param_noise_threshold_ph, lambda: param_noise_threshold))
+
+        # Put everything together.
+        deterministic_actions = tf.argmax(q_values_perturbed, axis=1)
+        batch_size = tf.shape(observations_ph.get())[0]
+        random_actions = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64)
+        chose_random = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
+        stochastic_actions = tf.where(chose_random, random_actions, deterministic_actions)
+
+        output_actions = tf.cond(stochastic_ph, lambda: stochastic_actions, lambda: deterministic_actions)
+        update_eps_expr = eps.assign(tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
+        updates = [
+            update_eps_expr,
+            tf.cond(reset_ph, lambda: perturb_vars(original_scope="q_func", perturbed_scope="perturbed_q_func"), lambda: tf.group(*[])),
+            tf.cond(update_param_noise_scale_ph, lambda: update_scale(), lambda: tf.Variable(0., trainable=False)),
+            update_param_noise_threshold_expr,
+        ]
+        act = U.function(inputs=[observations_ph, stochastic_ph, update_eps_ph, reset_ph, update_param_noise_threshold_ph, update_param_noise_scale_ph],
+                         outputs=output_actions,
+                         givens={update_eps_ph: -1.0, stochastic_ph: True, reset_ph: False, update_param_noise_threshold_ph: False, update_param_noise_scale_ph: False},
+                         updates=updates)
+        return act
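The update_scale() path above implements the adaptive scheme from Plappert et al.: perturb a disposable copy of the network, measure how far the perturbed policy moves in action space (KL between the softmaxed Q-values), and grow or shrink the noise scale by a fixed factor. A rough NumPy restatement of that rule is sketched below; the function and variable names are illustrative only and not part of baselines.

import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def adapt_noise_scale(q_clean, q_perturbed, scale, threshold, factor=1.01):
    """Grow the noise scale while the perturbed policy stays close to the clean one."""
    p, q = softmax(q_clean), softmax(q_perturbed)
    mean_kl = np.sum(p * (np.log(p) - np.log(q)), axis=-1).mean()
    return scale * factor if mean_kl < threshold else scale / factor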

def build_train(make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=None, gamma=1.0,
+                double_q=True, scope="deepq", reuse=None, param_noise=False, param_noise_filter_func=None):
    """Creates the train function:

    Parameters
@@ -160,6 +314,11 @@ def build_train(make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=
        optional scope for variable_scope.
    reuse: bool or None
        whether or not the variables should be reused. To be able to reuse the scope must be given.
+    param_noise: bool
+        whether or not to use parameter space noise (https://arxiv.org/abs/1706.01905)
+    param_noise_filter_func: tf.Variable -> bool
+        function that decides whether or not a variable should be perturbed. Only applicable
+        if param_noise is True. If set to None, default_param_noise_filter is used by default.

    Returns
    -------
@@ -175,7 +334,11 @@ def build_train(make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=
    debug: {str: function}
        a bunch of functions to print debug data like q_values.
    """
-    act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse)
+    if param_noise:
+        act_f = build_act_with_param_noise(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse,
+            param_noise_filter_func=param_noise_filter_func)
+    else:
+        act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse)

    with tf.variable_scope(scope, reuse=reuse):
        # set up placeholders
@@ -213,6 +376,7 @@ def build_train(make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=
        td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
        errors = U.huber_loss(td_error)
        weighted_error = tf.reduce_mean(importance_weights_ph * errors)

        # compute optimization op (potentially with gradient clipping)
        if grad_norm_clipping is not None:
            optimize_expr = U.minimize_and_clip(optimizer,
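For reference, enabling the new code path only requires threading param_noise=True through build_train; the Atari training script further down does exactly that. A condensed sketch of such a call is below; the observation shape and action count are placeholders, not values from the repo.

import tensorflow as tf
import baselines.common.tf_util as U
from baselines import deepq
from baselines.deepq.experiments.atari.model import model

act, train, update_target, debug = deepq.build_train(
    make_obs_ph=lambda name: U.Uint8Input((84, 84, 4), name=name),   # placeholder obs shape
    q_func=lambda *args, **kwargs: model(*args, layer_norm=True, **kwargs),
    num_actions=6,                                                    # placeholder action count
    optimizer=tf.train.AdamOptimizer(learning_rate=1e-4),
    gamma=0.99,
    grad_norm_clipping=10,
    param_noise=True,
)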
@@ -2,7 +2,14 @@ import tensorflow as tf
import tensorflow.contrib.layers as layers


-def model(img_in, num_actions, scope, reuse=False):
+def layer_norm_fn(x, relu=True):
+    x = layers.layer_norm(x, scale=True, center=True)
+    if relu:
+        x = tf.nn.relu(x)
+    return x
+
+
+def model(img_in, num_actions, scope, reuse=False, layer_norm=False):
    """As described in https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf"""
    with tf.variable_scope(scope, reuse=reuse):
        out = img_in
@@ -11,16 +18,19 @@ def model(img_in, num_actions, scope, reuse=False):
            out = layers.convolution2d(out, num_outputs=32, kernel_size=8, stride=4, activation_fn=tf.nn.relu)
            out = layers.convolution2d(out, num_outputs=64, kernel_size=4, stride=2, activation_fn=tf.nn.relu)
            out = layers.convolution2d(out, num_outputs=64, kernel_size=3, stride=1, activation_fn=tf.nn.relu)
-        out = layers.flatten(out)
+        conv_out = layers.flatten(out)

        with tf.variable_scope("action_value"):
-            out = layers.fully_connected(out, num_outputs=512, activation_fn=tf.nn.relu)
-            out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)
-        return out
+            value_out = layers.fully_connected(conv_out, num_outputs=512, activation_fn=None)
+            if layer_norm:
+                value_out = layer_norm_fn(value_out, relu=True)
+            else:
+                value_out = tf.nn.relu(value_out)
+            value_out = layers.fully_connected(value_out, num_outputs=num_actions, activation_fn=None)
+        return value_out


-def dueling_model(img_in, num_actions, scope, reuse=False):
+def dueling_model(img_in, num_actions, scope, reuse=False, layer_norm=False):
    """As described in https://arxiv.org/abs/1511.06581"""
    with tf.variable_scope(scope, reuse=reuse):
        out = img_in
@@ -29,15 +39,22 @@ def dueling_model(img_in, num_actions, scope, reuse=False):
            out = layers.convolution2d(out, num_outputs=32, kernel_size=8, stride=4, activation_fn=tf.nn.relu)
            out = layers.convolution2d(out, num_outputs=64, kernel_size=4, stride=2, activation_fn=tf.nn.relu)
            out = layers.convolution2d(out, num_outputs=64, kernel_size=3, stride=1, activation_fn=tf.nn.relu)
-        out = layers.flatten(out)
+        conv_out = layers.flatten(out)

        with tf.variable_scope("state_value"):
-            state_hidden = layers.fully_connected(out, num_outputs=512, activation_fn=tf.nn.relu)
+            state_hidden = layers.fully_connected(conv_out, num_outputs=512, activation_fn=None)
+            if layer_norm:
+                state_hidden = layer_norm_fn(state_hidden, relu=True)
+            else:
+                state_hidden = tf.nn.relu(state_hidden)
            state_score = layers.fully_connected(state_hidden, num_outputs=1, activation_fn=None)
        with tf.variable_scope("action_value"):
-            actions_hidden = layers.fully_connected(out, num_outputs=512, activation_fn=tf.nn.relu)
+            actions_hidden = layers.fully_connected(conv_out, num_outputs=512, activation_fn=None)
+            if layer_norm:
+                actions_hidden = layer_norm_fn(actions_hidden, relu=True)
+            else:
+                actions_hidden = tf.nn.relu(actions_hidden)
            action_scores = layers.fully_connected(actions_hidden, num_outputs=num_actions, activation_fn=None)
            action_scores_mean = tf.reduce_mean(action_scores, 1)
            action_scores = action_scores - tf.expand_dims(action_scores_mean, 1)

        return state_score + action_scores
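layer_norm_fn is a thin wrapper around tf.contrib.layers.layer_norm, which normalizes each sample's activations across the feature dimension and then applies a learned gain and bias. A rough NumPy equivalent is sketched below; the epsilon value is an assumption for illustration, not taken from the TF implementation.

import numpy as np

def layer_norm(x, gain=1.0, bias=0.0, eps=1e-5):
    """Normalize each row to zero mean / unit variance, then rescale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gain * (x - mean) / np.sqrt(var + eps) + bias

h = np.random.randn(32, 512)             # a batch of pre-activations
h = np.maximum(layer_norm(h), 0.0)       # layer norm followed by ReLU, as in layer_norm_fn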
@@ -5,6 +5,7 @@ import os
import tensorflow as tf
import tempfile
import time
+import json

import baselines.common.tf_util as U

@@ -40,10 +41,16 @@ def parse_args():
    parser.add_argument("--batch-size", type=int, default=32, help="number of transitions to optimize at the same time")
    parser.add_argument("--learning-freq", type=int, default=4, help="number of iterations between every optimization step")
    parser.add_argument("--target-update-freq", type=int, default=40000, help="number of iterations between every target network update")
+    parser.add_argument("--param-noise-update-freq", type=int, default=50, help="number of iterations between every re-scaling of the parameter noise")
+    parser.add_argument("--param-noise-reset-freq", type=int, default=10000, help="maximum number of steps to take per episode before re-perturbing the exploration policy")
+    parser.add_argument("--param-noise-threshold", type=float, default=0.05, help="the desired KL divergence between perturbed and non-perturbed policy. set to < 0 to use a KL divergence relative to the eps-greedy exploration")
    # Bells and whistles
    boolean_flag(parser, "double-q", default=True, help="whether or not to use double q learning")
    boolean_flag(parser, "dueling", default=False, help="whether or not to use dueling model")
    boolean_flag(parser, "prioritized", default=False, help="whether or not to use prioritized replay buffer")
+    boolean_flag(parser, "param-noise", default=False, help="whether or not to use parameter space noise for exploration")
+    boolean_flag(parser, "layer-norm", default=False, help="whether or not to use layer norm (should be True if param_noise is used)")
+    boolean_flag(parser, "gym-monitor", default=False, help="whether or not to use an OpenAI Gym monitor (results in slower training due to video recording)")
    parser.add_argument("--prioritized-alpha", type=float, default=0.6, help="alpha parameter for prioritized replay buffer")
    parser.add_argument("--prioritized-beta0", type=float, default=0.4, help="initial value of beta parameters for prioritized replay")
    parser.add_argument("--prioritized-eps", type=float, default=1e-6, help="eps parameter for prioritized replay buffer")
@@ -104,8 +111,11 @@ def maybe_load_model(savedir, container):

if __name__ == '__main__':
    args = parse_args()

    # Parse savedir and azure container.
    savedir = args.save_dir
+    if savedir is None:
+        savedir = os.getenv('OPENAI_LOGDIR', None)
    if args.save_azure_container is not None:
        account_name, account_key, container_name = args.save_azure_container.split(":")
        container = Container(account_name=account_name,
@@ -123,16 +133,27 @@ if __name__ == '__main__':
    set_global_seeds(args.seed)
    env.unwrapped.seed(args.seed)

+    if args.gym_monitor and savedir:
+        env = gym.wrappers.Monitor(env, os.path.join(savedir, 'gym_monitor'), force=True)
+
+    if savedir:
+        with open(os.path.join(savedir, 'args.json'), 'w') as f:
+            json.dump(vars(args), f)
+
    with U.make_session(4) as sess:
        # Create training graph and replay buffer
+        def model_wrapper(img_in, num_actions, scope, **kwargs):
+            actual_model = dueling_model if args.dueling else model
+            return actual_model(img_in, num_actions, scope, layer_norm=args.layer_norm, **kwargs)
        act, train, update_target, debug = deepq.build_train(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
-            q_func=dueling_model if args.dueling else model,
+            q_func=model_wrapper,
            num_actions=env.action_space.n,
            optimizer=tf.train.AdamOptimizer(learning_rate=args.lr, epsilon=1e-4),
            gamma=0.99,
            grad_norm_clipping=10,
-            double_q=args.double_q
+            double_q=args.double_q,
+            param_noise=args.param_noise
        )

        approximate_num_iters = args.num_steps / 4
@@ -162,17 +183,46 @@ if __name__ == '__main__':
        steps_per_iter = RunningAvg(0.999)
        iteration_time_est = RunningAvg(0.999)
        obs = env.reset()
+        num_iters_since_reset = 0
+        reset = True

        # Main training loop
        while True:
            num_iters += 1
+            num_iters_since_reset += 1

            # Take action and store transition in the replay buffer.
-            action = act(np.array(obs)[None], update_eps=exploration.value(num_iters))[0]
+            kwargs = {}
+            if not args.param_noise:
+                update_eps = exploration.value(num_iters)
+                update_param_noise_threshold = 0.
+            else:
+                if args.param_noise_reset_freq > 0 and num_iters_since_reset > args.param_noise_reset_freq:
+                    # Reset param noise policy since we have exceeded the maximum number of steps without a reset.
+                    reset = True
+
+                update_eps = 0.01  # ensures that we cannot get stuck completely
+                if args.param_noise_threshold >= 0.:
+                    update_param_noise_threshold = args.param_noise_threshold
+                else:
+                    # Compute the threshold such that the KL divergence between perturbed and non-perturbed
+                    # policy is comparable to eps-greedy exploration with eps = exploration.value(t).
+                    # See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
+                    # for detailed explanation.
+                    update_param_noise_threshold = -np.log(1. - exploration.value(num_iters) + exploration.value(num_iters) / float(env.action_space.n))
+                kwargs['reset'] = reset
+                kwargs['update_param_noise_threshold'] = update_param_noise_threshold
+                kwargs['update_param_noise_scale'] = (num_iters % args.param_noise_update_freq == 0)
+
+            action = act(np.array(obs)[None], update_eps=update_eps, **kwargs)[0]
+            reset = False
            new_obs, rew, done, info = env.step(action)
            replay_buffer.add(obs, action, rew, new_obs, float(done))
            obs = new_obs
            if done:
+                num_iters_since_reset = 0
                obs = env.reset()
+                reset = True

            if (num_iters > max(5 * args.batch_size, args.replay_buffer_size // 20) and
                    num_iters % args.learning_freq == 0):
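The expression used above for update_param_noise_threshold follows from a short calculation (Appendix C.1 of the paper): take the greedy policy pi, which puts all of its mass on the argmax action, and compare it to the eps-greedy policy pi-tilde over |A| actions, which keeps probability 1 - eps + eps/|A| on that same action. Their KL divergence reduces to exactly the quantity the code computes:

D_{\mathrm{KL}}\bigl(\pi \,\|\, \tilde{\pi}\bigr)
  = \sum_{a} \pi(a)\,\log\frac{\pi(a)}{\tilde{\pi}(a)}
  = 1 \cdot \log\frac{1}{1 - \epsilon + \epsilon/|A|}
  = -\log\!\Bigl(1 - \epsilon + \frac{\epsilon}{|A|}\Bigr)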
@@ -203,7 +253,7 @@ if __name__ == '__main__':
                maybe_save_model(savedir, container, {
                    'replay_buffer': replay_buffer,
                    'num_iters': num_iters,
-                    'monitor_state': monitored_env.get_state()
+                    'monitor_state': monitored_env.get_state(),
                })

            if info["steps"] > args.num_steps:
baselines/deepq/experiments/enjoy_mountaincar.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+import gym
+
+from baselines import deepq
+
+
+def main():
+    env = gym.make("MountainCar-v0")
+    act = deepq.load("mountaincar_model.pkl")
+
+    while True:
+        obs, done = env.reset(), False
+        episode_rew = 0
+        while not done:
+            env.render()
+            obs, rew, done, _ = env.step(act(obs[None])[0])
+            episode_rew += rew
+        print("Episode reward", episode_rew)
+
+
+if __name__ == '__main__':
+    main()
baselines/deepq/experiments/train_mountaincar.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+import gym
+
+from baselines import deepq
+
+
+def main():
+    env = gym.make("MountainCar-v0")
+    # Enabling layer_norm here is important for parameter space noise!
+    model = deepq.models.mlp([64], layer_norm=True)
+    act = deepq.learn(
+        env,
+        q_func=model,
+        lr=1e-3,
+        max_timesteps=100000,
+        buffer_size=50000,
+        exploration_fraction=0.1,
+        exploration_final_eps=0.1,
+        print_freq=10,
+        param_noise=True
+    )
+    print("Saving model to mountaincar_model.pkl")
+    act.save("mountaincar_model.pkl")
+
+
+if __name__ == '__main__':
+    main()
@@ -2,16 +2,19 @@ import tensorflow as tf
import tensorflow.contrib.layers as layers


-def _mlp(hiddens, inpt, num_actions, scope, reuse=False):
+def _mlp(hiddens, inpt, num_actions, scope, reuse=False, layer_norm=False):
    with tf.variable_scope(scope, reuse=reuse):
        out = inpt
        for hidden in hiddens:
-            out = layers.fully_connected(out, num_outputs=hidden, activation_fn=tf.nn.relu)
-        out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)
-        return out
+            out = layers.fully_connected(out, num_outputs=hidden, activation_fn=None)
+            if layer_norm:
+                out = layers.layer_norm(out, center=True, scale=True)
+            out = tf.nn.relu(out)
+        q_out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)
+        return q_out


-def mlp(hiddens=[]):
+def mlp(hiddens=[], layer_norm=False):
    """This model takes as input an observation and returns values of all actions.

    Parameters
@@ -24,10 +27,10 @@ def mlp(hiddens=[]):
    q_func: function
        q_function for DQN algorithm.
    """
-    return lambda *args, **kwargs: _mlp(hiddens, *args, **kwargs)
+    return lambda *args, **kwargs: _mlp(hiddens, layer_norm=layer_norm, *args, **kwargs)


-def _cnn_to_mlp(convs, hiddens, dueling, inpt, num_actions, scope, reuse=False):
+def _cnn_to_mlp(convs, hiddens, dueling, inpt, num_actions, scope, reuse=False, layer_norm=False):
    with tf.variable_scope(scope, reuse=reuse):
        out = inpt
        with tf.variable_scope("convnet"):
@@ -37,28 +40,34 @@ def _cnn_to_mlp(convs, hiddens, dueling, inpt, num_actions, scope, reuse=False):
                                       kernel_size=kernel_size,
                                       stride=stride,
                                       activation_fn=tf.nn.relu)
-        out = layers.flatten(out)
+        conv_out = layers.flatten(out)
        with tf.variable_scope("action_value"):
-            action_out = out
+            action_out = conv_out
            for hidden in hiddens:
-                action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=tf.nn.relu)
+                action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
+                if layer_norm:
+                    action_out = layers.layer_norm(action_out, center=True, scale=True)
+                action_out = tf.nn.relu(action_out)
            action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)

        if dueling:
            with tf.variable_scope("state_value"):
-                state_out = out
+                state_out = conv_out
                for hidden in hiddens:
-                    state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=tf.nn.relu)
+                    state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None)
+                    if layer_norm:
+                        state_out = layers.layer_norm(state_out, center=True, scale=True)
+                    state_out = tf.nn.relu(state_out)
                state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
            action_scores_mean = tf.reduce_mean(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
-            return state_score + action_scores_centered
+            q_out = state_score + action_scores_centered
        else:
-            return action_scores
+            q_out = action_scores
-        return out
+        return q_out


-def cnn_to_mlp(convs, hiddens, dueling=False):
+def cnn_to_mlp(convs, hiddens, dueling=False, layer_norm=False):
    """This model takes as input an observation and returns values of all actions.

    Parameters
@@ -78,5 +87,5 @@ def cnn_to_mlp(convs, hiddens, dueling=False):
        q_function for DQN algorithm.
    """

-    return lambda *args, **kwargs: _cnn_to_mlp(convs, hiddens, dueling, *args, **kwargs)
+    return lambda *args, **kwargs: _cnn_to_mlp(convs, hiddens, dueling, layer_norm=layer_norm, *args, **kwargs)
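With the layer_norm switch threaded through these builders, an Atari-style Q-network that plays well with parameter space noise can be assembled in one call. The sketch below is only an example: the convolution and hidden sizes mirror the DQN defaults used elsewhere in the repo, not a prescribed configuration.

from baselines import deepq

q_func = deepq.models.cnn_to_mlp(
    convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],  # (filters, kernel, stride) per conv layer
    hiddens=[256],
    dueling=True,
    layer_norm=True,  # recommended when param_noise is used, as noted in train_mountaincar.py above
)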
@@ -94,6 +94,8 @@ def learn(env,
          prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6,
          num_cpu=16,
+          param_noise=False,
+          param_noise_threshold=0.05,
          callback=None):
    """Train a deepq model.

@@ -176,13 +178,15 @@ def learn(env,
        num_actions=env.action_space.n,
        optimizer=tf.train.AdamOptimizer(learning_rate=lr),
        gamma=gamma,
-        grad_norm_clipping=10
+        grad_norm_clipping=10,
+        param_noise=param_noise
    )
    act_params = {
        'make_obs_ph': make_obs_ph,
        'q_func': q_func,
        'num_actions': env.action_space.n,
    }

    # Create the replay buffer
    if prioritized_replay:
        replay_buffer = PrioritizedReplayBuffer(buffer_size, alpha=prioritized_replay_alpha)
@@ -206,6 +210,7 @@ def learn(env,
    episode_rewards = [0.0]
    saved_mean_reward = None
    obs = env.reset()
+    reset = True
    with tempfile.TemporaryDirectory() as td:
        model_saved = False
        model_file = os.path.join(td, "model")
@@ -214,7 +219,25 @@ def learn(env,
            if callback is not None:
                if callback(locals(), globals()):
                    break
            # Take action and update exploration to the newest value
-            action = act(np.array(obs)[None], update_eps=exploration.value(t))[0]
+            kwargs = {}
+            if not param_noise:
+                update_eps = exploration.value(t)
+                update_param_noise_threshold = 0.
+            else:
+                update_eps = 0.
+                if param_noise_threshold >= 0.:
+                    update_param_noise_threshold = param_noise_threshold
+                else:
+                    # Compute the threshold such that the KL divergence between perturbed and non-perturbed
+                    # policy is comparable to eps-greedy exploration with eps = exploration.value(t).
+                    # See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
+                    # for detailed explanation.
+                    update_param_noise_threshold = -np.log(1. - exploration.value(t) + exploration.value(t) / float(env.action_space.n))
+                kwargs['reset'] = reset
+                kwargs['update_param_noise_threshold'] = update_param_noise_threshold
+                kwargs['update_param_noise_scale'] = True
+            action = act(np.array(obs)[None], update_eps=update_eps, **kwargs)[0]
+            reset = False
            new_obs, rew, done, _ = env.step(action)
            # Store transition in the replay buffer.
            replay_buffer.add(obs, action, rew, new_obs, float(done))
@@ -224,6 +247,7 @@ def learn(env,
            if done:
                obs = env.reset()
                episode_rewards.append(0.0)
+                reset = True

            if t > learning_starts and t % train_freq == 0:
                # Minimize the error in Bellman's equation on a batch sampled from replay buffer.
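To get a feeling for the magnitudes this produces, the snippet below (purely illustrative) evaluates the same expression for a few exploration values in an environment with six actions:

import numpy as np

def kl_threshold(eps, n_actions):
    # Same expression as update_param_noise_threshold above.
    return -np.log(1. - eps + eps / float(n_actions))

for eps in (1.0, 0.1, 0.01):
    print(eps, round(kl_threshold(eps, n_actions=6), 4))
# approximately: 1.0 -> 1.7918, 0.1 -> 0.087, 0.01 -> 0.0084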
@@ -17,7 +17,7 @@ import time
import datetime
import tempfile

-LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json']
+LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json', 'tensorboard']

DEBUG = 10
INFO = 20
setup.py
@@ -19,6 +19,7 @@ setup(name='baselines',
          'tensorflow >= 1.0.0',
          'azure==1.0.3',
          'progressbar2',
+          'mpi4py',
      ],
      description="OpenAI baselines: high quality implementations of reinforcement learning algorithms",
      author="OpenAI",