Baselines for Tensorflow 2.0. (#978)

* Baselines for Tensorflow 2.0.

Please note that:
1. ACER, ACKTR, and GAIL are still under development by external
contributors.
2. HER is still under development by tanzheny@google.com.

* Some cleanup.

* Addressing some comments.
tanzhenyu
2019-08-08 11:03:17 -07:00
committed by pzhokhov
parent c57528573e
commit d1a05a0dd2
138 changed files with 1216 additions and 8215 deletions
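For orientation before the per-file diffs: the port replaces the TF1 graph/session idiom (placeholders, variable scopes, sess.run) with the TF2 idiom used throughout the new code, i.e. a tf.keras.Model subclass whose train step is a @tf.function that computes gradients with tf.GradientTape. Below is a minimal, hypothetical sketch of that pattern; it is not code from this commit, just an illustration of the shape the real Model follows (with PolicyWithValue supplying the network and InverseLinearTimeDecay the learning-rate schedule).

import tensorflow as tf

class ToyModel(tf.keras.Model):
    """Stand-in illustrating the TF2 pattern adopted by this port."""
    def __init__(self):
        super().__init__(name='ToyModel')
        self.dense = tf.keras.layers.Dense(1)
        self.dense.build((None, 4))  # create variables eagerly, before tracing
        self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=7e-4)

    @tf.function
    def train(self, x, y):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(self.dense(x) - y))
        var_list = tape.watched_variables()
        grads = tape.gradient(loss, var_list)
        self.optimizer.apply_gradients(zip(grads, var_list))
        return loss

model = ToyModel()
print(float(model.train(tf.random.normal((8, 4)), tf.zeros((8, 1)))))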

View File

@@ -150,7 +150,7 @@ respectively. Note that these results may be not on the latest version of the co
To cite this repository in publications:
@misc{baselines,
author = {Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai and Zhokhov, Peter},
author = {Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Tan, Zhenyu and Wu, Yuhuai and Zhokhov, Peter},
title = {OpenAI Baselines},
year = {2017},
publisher = {GitHub},

View File

@@ -1,22 +1,19 @@
import time
import functools
import tensorflow as tf
from baselines import logger
from baselines.common import set_global_seeds, explained_variance
from baselines.common import tf_util
from baselines.common.policies import build_policy
from baselines.common.models import get_network_builder
from baselines.common.policies import PolicyWithValue
from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.a2c.utils import InverseLinearTimeDecay
from baselines.a2c.runner import Runner
from baselines.ppo2.ppo2 import safemean
import os.path as osp
from collections import deque
from tensorflow import losses
class Model(object):
class Model(tf.keras.Model):
"""
We use this class to:
@@ -30,90 +27,42 @@ class Model(object):
save/load():
- Save/load the model
"""
def __init__(self, policy, env, nsteps,
def __init__(self, *, ac_space, policy_network, nupdates,
ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6)):
sess = tf_util.get_session()
nenvs = env.num_envs
nbatch = nenvs*nsteps
super(Model, self).__init__(name='A2CModel')
self.train_model = PolicyWithValue(ac_space, policy_network, value_network=None, estimate_q=False)
lr_schedule = InverseLinearTimeDecay(initial_learning_rate=lr, nupdates=nupdates)
self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule, rho=alpha, epsilon=epsilon)
self.ent_coef = ent_coef
self.vf_coef = vf_coef
self.max_grad_norm = max_grad_norm
self.step = self.train_model.step
self.value = self.train_model.value
self.initial_state = self.train_model.initial_state
with tf.variable_scope('a2c_model', reuse=tf.AUTO_REUSE):
# step_model is used for sampling
step_model = policy(nenvs, 1, sess)
@tf.function
def train(self, obs, states, rewards, masks, actions, values):
advs = rewards - values
with tf.GradientTape() as tape:
policy_latent = self.train_model.policy_network(obs)
pd, _ = self.train_model.pdtype.pdfromlatent(policy_latent)
neglogpac = pd.neglogp(actions)
entropy = tf.reduce_mean(pd.entropy())
vpred = self.train_model.value(obs)
vf_loss = tf.reduce_mean(tf.square(vpred - rewards))
pg_loss = tf.reduce_mean(advs * neglogpac)
loss = pg_loss - entropy * self.ent_coef + vf_loss * self.vf_coef
# train_model is used to train our network
train_model = policy(nbatch, nsteps, sess)
var_list = tape.watched_variables()
grads = tape.gradient(loss, var_list)
grads, _ = tf.clip_by_global_norm(grads, self.max_grad_norm)
grads_and_vars = list(zip(grads, var_list))
self.optimizer.apply_gradients(grads_and_vars)
A = tf.placeholder(train_model.action.dtype, train_model.action.shape)
ADV = tf.placeholder(tf.float32, [nbatch])
R = tf.placeholder(tf.float32, [nbatch])
LR = tf.placeholder(tf.float32, [])
# Calculate the loss
# Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
# Policy loss
neglogpac = train_model.pd.neglogp(A)
# L = A(s,a) * -logpi(a|s)
pg_loss = tf.reduce_mean(ADV * neglogpac)
# Entropy is used to improve exploration by limiting premature convergence to a suboptimal policy.
entropy = tf.reduce_mean(train_model.pd.entropy())
# Value loss
vf_loss = losses.mean_squared_error(tf.squeeze(train_model.vf), R)
loss = pg_loss - entropy*ent_coef + vf_loss * vf_coef
# Update parameters using loss
# 1. Get the model parameters
params = find_trainable_variables("a2c_model")
# 2. Calculate the gradients
grads = tf.gradients(loss, params)
if max_grad_norm is not None:
# Clip the gradients (normalize)
grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
grads = list(zip(grads, params))
# zip pairs each gradient with its associated parameter
# For instance zip(ABCD, xyza) => Ax, By, Cz, Da
# 3. Make op for one policy and value update step of A2C
trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
_train = trainer.apply_gradients(grads)
lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
def train(obs, states, rewards, masks, actions, values):
# Here we calculate advantage A(s,a) = R + yV(s') - V(s)
# rewards = R + yV(s')
advs = rewards - values
for step in range(len(obs)):
cur_lr = lr.value()
td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, LR:cur_lr}
if states is not None:
td_map[train_model.S] = states
td_map[train_model.M] = masks
policy_loss, value_loss, policy_entropy, _ = sess.run(
[pg_loss, vf_loss, entropy, _train],
td_map
)
return policy_loss, value_loss, policy_entropy
self.train = train
self.train_model = train_model
self.step_model = step_model
self.step = step_model.step
self.value = step_model.value
self.initial_state = step_model.initial_state
self.save = functools.partial(tf_util.save_variables, sess=sess)
self.load = functools.partial(tf_util.load_variables, sess=sess)
tf.global_variables_initializer().run(session=sess)
return pg_loss, vf_loss, entropy
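As a quick numeric illustration of the loss assembled in train() above (advantage = n-step return minus value baseline; total loss = policy-gradient term minus entropy bonus plus weighted value loss), here is a small NumPy check with made-up numbers:

import numpy as np

ent_coef, vf_coef = 0.01, 0.5
rewards = np.array([1.5, 0.7])    # n-step returns R + gamma*V(s') computed by the runner
values  = np.array([1.0, 1.0])    # V(s) predicted at collection time
advs = rewards - values           # advantages: [0.5, -0.3]

neglogpac = np.array([0.9, 1.2])  # -log pi(a|s) for the sampled actions (made up)
entropy   = 1.1                   # mean policy entropy (made up)
vpred     = np.array([1.1, 0.8])  # value head output on the training batch (made up)

pg_loss = np.mean(advs * neglogpac)            # (0.45 - 0.36) / 2 = 0.045
vf_loss = np.mean(np.square(vpred - rewards))  # (0.16 + 0.01) / 2 = 0.085
loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
print(loss)  # 0.045 - 0.011 + 0.0425 = 0.0765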
def learn(
@@ -185,31 +134,53 @@ def learn(
set_global_seeds(seed)
total_timesteps = int(total_timesteps)
# Get the nb of env
nenvs = env.num_envs
policy = build_policy(env, network, **network_kwargs)
# Get state_space and action_space
ob_space = env.observation_space
ac_space = env.action_space
if isinstance(network, str):
network_type = network
policy_network_fn = get_network_builder(network_type)(**network_kwargs)
policy_network = policy_network_fn(ob_space.shape)
# Calculate the batch_size
nbatch = nenvs * nsteps
nupdates = total_timesteps // nbatch
# Instantiate the model object (that creates step_model and train_model)
model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
model = Model(ac_space=ac_space, policy_network=policy_network, nupdates=nupdates, ent_coef=ent_coef, vf_coef=vf_coef,
max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps)
if load_path is not None:
model.load(load_path)
load_path = osp.expanduser(load_path)
ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=None)
ckpt.restore(manager.latest_checkpoint)
# Instantiate the runner object
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
epinfobuf = deque(maxlen=100)
# Calculate the batch_size
nbatch = nenvs*nsteps
# Start total timer
tstart = time.time()
for update in range(1, total_timesteps//nbatch+1):
for update in range(1, nupdates+1):
# Get mini batch of experiences
obs, states, rewards, masks, actions, values, epinfos = runner.run()
epinfobuf.extend(epinfos)
obs = tf.constant(obs)
if states is not None:
states = tf.constant(states)
rewards = tf.constant(rewards)
masks = tf.constant(masks)
actions = tf.constant(actions)
values = tf.constant(values)
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
nseconds = time.time()-tstart
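The hunk above only shows the restore path of the new object-based checkpointing. A save-side counterpart is not shown in this diff; the sketch below is an assumption about how such a checkpoint would typically be written with the same tf.train.Checkpoint / CheckpointManager pair (the model here is a stand-in, not the A2C Model):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])  # stand-in for the A2C Model
ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, './checkpoints', max_to_keep=5)
save_path = manager.save()               # writes ./checkpoints/ckpt-1 and updates the 'checkpoint' file
ckpt.restore(manager.latest_checkpoint)  # mirrors the restore call in learn() above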

View File

@@ -1,3 +1,4 @@
import tensorflow as tf
import numpy as np
from baselines.a2c.utils import discount_with_dones
from baselines.common.runners import AbstractEnvRunner
@@ -15,40 +16,37 @@ class Runner(AbstractEnvRunner):
def __init__(self, env, model, nsteps=5, gamma=0.99):
super().__init__(env=env, model=model, nsteps=nsteps)
self.gamma = gamma
self.batch_action_shape = [x if x is not None else -1 for x in model.train_model.action.shape.as_list()]
self.ob_dtype = model.train_model.X.dtype.as_numpy_dtype
def run(self):
# We initialize the lists that will contain the mb of experiences
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
mb_states = self.states
epinfos = []
for n in range(self.nsteps):
for _ in range(self.nsteps):
# Given observations, take action and value (V(s))
# We already have self.obs because the Runner superclass runs self.obs[:] = env.reset() on init
actions, values, states, _ = self.model.step(self.obs, S=self.states, M=self.dones)
obs = tf.constant(self.obs)
actions, values, self.states, _ = self.model.step(obs)
actions = actions._numpy()
# Append the experiences
mb_obs.append(np.copy(self.obs))
mb_obs.append(self.obs.copy())
mb_actions.append(actions)
mb_values.append(values)
mb_values.append(values._numpy())
mb_dones.append(self.dones)
# Take actions in env and look the results
obs, rewards, dones, infos = self.env.step(actions)
self.obs[:], rewards, self.dones, infos = self.env.step(actions)
for info in infos:
maybeepinfo = info.get('episode')
if maybeepinfo: epinfos.append(maybeepinfo)
self.states = states
self.dones = dones
self.obs = obs
mb_rewards.append(rewards)
mb_dones.append(self.dones)
# Batch of steps to batch of rollouts
mb_obs = np.asarray(mb_obs, dtype=self.ob_dtype).swapaxes(1, 0).reshape(self.batch_ob_shape)
mb_obs = sf01(np.asarray(mb_obs, dtype=self.obs.dtype))
mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
mb_actions = np.asarray(mb_actions, dtype=self.model.train_model.action.dtype.name).swapaxes(1, 0)
mb_actions = sf01(np.asarray(mb_actions, dtype=actions.dtype))
mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0)
mb_masks = mb_dones[:, :-1]
@@ -57,7 +55,7 @@ class Runner(AbstractEnvRunner):
if self.gamma > 0.0:
# Discount/bootstrap off value fn
last_values = self.model.value(self.obs, S=self.states, M=self.dones).tolist()
last_values = self.model.value(tf.constant(self.obs))._numpy().tolist()
for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
rewards = rewards.tolist()
dones = dones.tolist()
@@ -68,9 +66,15 @@ class Runner(AbstractEnvRunner):
mb_rewards[n] = rewards
mb_actions = mb_actions.reshape(self.batch_action_shape)
mb_rewards = mb_rewards.flatten()
mb_values = mb_values.flatten()
mb_masks = mb_masks.flatten()
return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, epinfos
def sf01(arr):
"""
swap and then flatten axes 0 and 1
"""
s = arr.shape
return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])
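For example (a tiny, hypothetical shape check, assuming the sf01 defined above is in scope), a rollout of nsteps x nenv observations becomes one flat batch:

import numpy as np

nsteps, nenv = 5, 4
mb_obs = np.zeros((nsteps, nenv, 84, 84, 4))  # per-step, per-env observations
flat = sf01(mb_obs)
print(flat.shape)  # (20, 84, 84, 4): step and env axes merged into the batch dimension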

View File

@@ -1,21 +1,5 @@
import os
import numpy as np
import tensorflow as tf
from collections import deque
def sample(logits):
noise = tf.random_uniform(tf.shape(logits))
return tf.argmax(logits - tf.log(-tf.log(noise)), 1)
def cat_entropy(logits):
a0 = logits - tf.reduce_max(logits, 1, keepdims=True)
ea0 = tf.exp(a0)
z0 = tf.reduce_sum(ea0, 1, keepdims=True)
p0 = ea0 / z0
return tf.reduce_sum(p0 * (tf.log(z0) - a0), 1)
def cat_entropy_softmax(p0):
return - tf.reduce_sum(p0 * tf.log(p0 + 1e-6), axis = 1)
def ortho_init(scale=1.0):
def _ortho_init(shape, dtype, partition_info=None):
@@ -34,115 +18,18 @@ def ortho_init(scale=1.0):
return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
return _ortho_init
def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False):
if data_format == 'NHWC':
channel_ax = 3
strides = [1, stride, stride, 1]
bshape = [1, 1, 1, nf]
elif data_format == 'NCHW':
channel_ax = 1
strides = [1, 1, stride, stride]
bshape = [1, nf, 1, 1]
else:
raise NotImplementedError
bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1]
nin = x.get_shape()[channel_ax].value
wshape = [rf, rf, nin, nf]
with tf.variable_scope(scope):
w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale))
b = tf.get_variable("b", bias_var_shape, initializer=tf.constant_initializer(0.0))
if not one_dim_bias and data_format == 'NHWC':
b = tf.reshape(b, bshape)
return tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format) + b
def conv(scope, *, nf, rf, stride, activation, pad='valid', init_scale=1.0, data_format='channels_last'):
with tf.name_scope(scope):
layer = tf.keras.layers.Conv2D(filters=nf, kernel_size=rf, strides=stride, padding=pad,
data_format=data_format, kernel_initializer=ortho_init(init_scale))
return layer
def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
with tf.variable_scope(scope):
nin = x.get_shape()[1].value
w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias))
return tf.matmul(x, w)+b
def batch_to_seq(h, nbatch, nsteps, flat=False):
if flat:
h = tf.reshape(h, [nbatch, nsteps])
else:
h = tf.reshape(h, [nbatch, nsteps, -1])
return [tf.squeeze(v, [1]) for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)]
def seq_to_batch(h, flat = False):
shape = h[0].get_shape().as_list()
if not flat:
assert(len(shape) > 1)
nh = h[0].get_shape()[-1].value
return tf.reshape(tf.concat(axis=1, values=h), [-1, nh])
else:
return tf.reshape(tf.stack(values=h, axis=1), [-1])
def lstm(xs, ms, s, scope, nh, init_scale=1.0):
nbatch, nin = [v.value for v in xs[0].get_shape()]
with tf.variable_scope(scope):
wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0))
c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
for idx, (x, m) in enumerate(zip(xs, ms)):
c = c*(1-m)
h = h*(1-m)
z = tf.matmul(x, wx) + tf.matmul(h, wh) + b
i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
i = tf.nn.sigmoid(i)
f = tf.nn.sigmoid(f)
o = tf.nn.sigmoid(o)
u = tf.tanh(u)
c = f*c + i*u
h = o*tf.tanh(c)
xs[idx] = h
s = tf.concat(axis=1, values=[c, h])
return xs, s
def _ln(x, g, b, e=1e-5, axes=[1]):
u, s = tf.nn.moments(x, axes=axes, keep_dims=True)
x = (x-u)/tf.sqrt(s+e)
x = x*g+b
return x
def lnlstm(xs, ms, s, scope, nh, init_scale=1.0):
nbatch, nin = [v.value for v in xs[0].get_shape()]
with tf.variable_scope(scope):
wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
gx = tf.get_variable("gx", [nh*4], initializer=tf.constant_initializer(1.0))
bx = tf.get_variable("bx", [nh*4], initializer=tf.constant_initializer(0.0))
wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
gh = tf.get_variable("gh", [nh*4], initializer=tf.constant_initializer(1.0))
bh = tf.get_variable("bh", [nh*4], initializer=tf.constant_initializer(0.0))
b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0))
gc = tf.get_variable("gc", [nh], initializer=tf.constant_initializer(1.0))
bc = tf.get_variable("bc", [nh], initializer=tf.constant_initializer(0.0))
c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
for idx, (x, m) in enumerate(zip(xs, ms)):
c = c*(1-m)
h = h*(1-m)
z = _ln(tf.matmul(x, wx), gx, bx) + _ln(tf.matmul(h, wh), gh, bh) + b
i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
i = tf.nn.sigmoid(i)
f = tf.nn.sigmoid(f)
o = tf.nn.sigmoid(o)
u = tf.tanh(u)
c = f*c + i*u
h = o*tf.tanh(_ln(c, gc, bc))
xs[idx] = h
s = tf.concat(axis=1, values=[c, h])
return xs, s
def conv_to_fc(x):
nh = np.prod([v.value for v in x.get_shape()[1:]])
x = tf.reshape(x, [-1, nh])
return x
def fc(input_shape, scope, nh, *, init_scale=1.0, init_bias=0.0):
with tf.name_scope(scope):
layer = tf.keras.layers.Dense(units=nh, kernel_initializer=ortho_init(init_scale),
bias_initializer=tf.keras.initializers.Constant(init_bias))
layer.build(input_shape)
return layer
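A small hedged usage sketch of the new Keras layer factories above (shapes and names are made up for illustration, and both helpers are assumed to be importable from baselines.a2c.utils):

import tensorflow as tf
from baselines.a2c.utils import conv, fc

# fc pre-builds a Dense layer for a known input shape.
pi_head = fc(input_shape=(None, 64), scope='pi', nh=6, init_scale=0.01)
logits = pi_head(tf.zeros((8, 64)))   # -> shape (8, 6)

# conv returns an (un-built) Conv2D layer; it builds on its first call.
c1 = conv('c1', nf=32, rf=8, stride=4, activation=None)
feats = c1(tf.zeros((8, 84, 84, 4)))  # -> shape (8, 20, 20, 32) with 'valid' padding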
def discount_with_dones(rewards, dones, gamma):
discounted = []
@@ -152,131 +39,25 @@ def discount_with_dones(rewards, dones, gamma):
discounted.append(r)
return discounted[::-1]
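As a quick worked example of what discount_with_dones computes (a backward discounted sum that resets at episode boundaries; this assumes the standard Baselines implementation of the body elided by the hunk above):

from baselines.a2c.utils import discount_with_dones

rewards = [1.0, 1.0, 1.0, 1.0]
dones   = [0,   1,   0,   0]    # an episode ends after the second step
gamma   = 0.5

# Working backwards: r3 = 1.0; r2 = 1.0 + 0.5*1.0 = 1.5; r1 resets at the done, so r1 = 1.0;
# r0 = 1.0 + 0.5*1.0 = 1.5.
print(discount_with_dones(rewards, dones, gamma))  # [1.5, 1.0, 1.5, 1.0]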
def find_trainable_variables(key):
return tf.trainable_variables(key)
class InverseLinearTimeDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
def __init__(self, initial_learning_rate, nupdates, name="InverseLinearTimeDecay"):
super(InverseLinearTimeDecay, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.nupdates = nupdates
self.name = name
def make_path(f):
return os.makedirs(f, exist_ok=True)
def __call__(self, step):
with tf.name_scope(self.name):
initial_learning_rate = tf.convert_to_tensor(self.initial_learning_rate, name="initial_learning_rate")
dtype = initial_learning_rate.dtype
step_t = tf.cast(step, dtype)
nupdates_t = tf.convert_to_tensor(self.nupdates, dtype=dtype)
tf.assert_less(step_t, nupdates_t)
return initial_learning_rate * (1. - step_t / nupdates_t)
def constant(p):
return 1
def linear(p):
return 1-p
def middle_drop(p):
eps = 0.75
if 1-p<eps:
return eps*0.1
return 1-p
def double_linear_con(p):
p *= 2
eps = 0.125
if 1-p<eps:
return eps
return 1-p
def double_middle_drop(p):
eps1 = 0.75
eps2 = 0.25
if 1-p<eps1:
if 1-p<eps2:
return eps2*0.5
return eps1*0.1
return 1-p
schedules = {
'linear':linear,
'constant':constant,
'double_linear_con': double_linear_con,
'middle_drop': middle_drop,
'double_middle_drop': double_middle_drop
}
class Scheduler(object):
def __init__(self, v, nvalues, schedule):
self.n = 0.
self.v = v
self.nvalues = nvalues
self.schedule = schedules[schedule]
def value(self):
current_value = self.v*self.schedule(self.n/self.nvalues)
self.n += 1.
return current_value
def value_steps(self, steps):
return self.v*self.schedule(steps/self.nvalues)
class EpisodeStats:
def __init__(self, nsteps, nenvs):
self.episode_rewards = []
for i in range(nenvs):
self.episode_rewards.append([])
self.lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths
self.rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards
self.nsteps = nsteps
self.nenvs = nenvs
def feed(self, rewards, masks):
rewards = np.reshape(rewards, [self.nenvs, self.nsteps])
masks = np.reshape(masks, [self.nenvs, self.nsteps])
for i in range(0, self.nenvs):
for j in range(0, self.nsteps):
self.episode_rewards[i].append(rewards[i][j])
if masks[i][j]:
l = len(self.episode_rewards[i])
s = sum(self.episode_rewards[i])
self.lenbuffer.append(l)
self.rewbuffer.append(s)
self.episode_rewards[i] = []
def mean_length(self):
if self.lenbuffer:
return np.mean(self.lenbuffer)
else:
return 0 # on the first params dump, no episodes are finished
def mean_reward(self):
if self.rewbuffer:
return np.mean(self.rewbuffer)
else:
return 0
# For ACER
def get_by_index(x, idx):
assert(len(x.get_shape()) == 2)
assert(len(idx.get_shape()) == 1)
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]), # flatten input
idx_flattened) # use flattened indices
return y
def check_shape(ts,shapes):
i = 0
for (t,shape) in zip(ts,shapes):
assert t.get_shape().as_list()==shape, "id " + str(i) + " shape " + str(t.get_shape()) + str(shape)
i += 1
def avg_norm(t):
return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(t), axis=-1)))
def gradient_add(g1, g2, param):
print([g1, g2, param.name])
assert (not (g1 is None and g2 is None)), param.name
if g1 is None:
return g2
elif g2 is None:
return g1
else:
return g1 + g2
def q_explained_variance(qpred, q):
_, vary = tf.nn.moments(q, axes=[0, 1])
_, varpred = tf.nn.moments(q - qpred, axes=[0, 1])
check_shape([vary, varpred], [[]] * 2)
return 1.0 - (varpred / vary)
def get_config(self):
return {
"initial_learning_rate": self.initial_learning_rate,
"nupdates": self.nupdates,
"name": self.name
}
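To see the new schedule in action (a small, hypothetical check with made-up values): the decayed rate reproduces the old Scheduler(..., schedule='linear') behaviour of scaling the initial rate by the remaining fraction of updates, and plugs into the optimizer exactly as Model.__init__ does above.

import tensorflow as tf
from baselines.a2c.utils import InverseLinearTimeDecay

lr_schedule = InverseLinearTimeDecay(initial_learning_rate=7e-4, nupdates=1000)
print(float(lr_schedule(0)))    # 7e-4    (no decay yet)
print(float(lr_schedule(500)))  # 3.5e-4  (halfway through training)

opt = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule, rho=0.99, epsilon=1e-5)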

View File

@@ -1,6 +0,0 @@
# ACER
- Original paper: https://arxiv.org/abs/1611.01224
- `python -m baselines.run --alg=acer --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options.
- also refer to the repo-wide [README.md](../../README.md#training-models)

View File

@@ -1,377 +0,0 @@
import time
import functools
import numpy as np
import tensorflow as tf
from baselines import logger
from baselines.common import set_global_seeds
from baselines.common.policies import build_policy
from baselines.common.tf_util import get_session, save_variables
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.a2c.utils import batch_to_seq, seq_to_batch
from baselines.a2c.utils import cat_entropy_softmax
from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.a2c.utils import EpisodeStats
from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance
from baselines.acer.buffer import Buffer
from baselines.acer.runner import Runner
# remove last step
def strip(var, nenvs, nsteps, flat = False):
vars = batch_to_seq(var, nenvs, nsteps + 1, flat)
return seq_to_batch(vars[:-1], flat)
def q_retrace(R, D, q_i, v, rho_i, nenvs, nsteps, gamma):
"""
Calculates q_retrace targets
:param R: Rewards
:param D: Dones
:param q_i: Q values for actions taken
:param v: V values
:param rho_i: Importance weight for each action
:return: Q_retrace values
"""
rho_bar = batch_to_seq(tf.minimum(1.0, rho_i), nenvs, nsteps, True) # list of len steps, shape [nenvs]
rs = batch_to_seq(R, nenvs, nsteps, True) # list of len steps, shape [nenvs]
ds = batch_to_seq(D, nenvs, nsteps, True) # list of len steps, shape [nenvs]
q_is = batch_to_seq(q_i, nenvs, nsteps, True)
vs = batch_to_seq(v, nenvs, nsteps + 1, True)
v_final = vs[-1]
qret = v_final
qrets = []
for i in range(nsteps - 1, -1, -1):
check_shape([qret, ds[i], rs[i], rho_bar[i], q_is[i], vs[i]], [[nenvs]] * 6)
qret = rs[i] + gamma * qret * (1.0 - ds[i])
qrets.append(qret)
qret = (rho_bar[i] * (qret - q_is[i])) + vs[i]
qrets = qrets[::-1]
qret = seq_to_batch(qrets, flat=True)
return qret
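As a compact illustration of the recursion above (a hypothetical single-environment NumPy sketch with made-up numbers, not part of this file):

import numpy as np

gamma   = 0.99
rewards = np.array([1.0, 0.0, 1.0])       # R
dones   = np.array([0.0, 0.0, 1.0])       # D
q_i     = np.array([0.9, 0.4, 0.8])       # Q(s_t, a_t) for the taken actions
v       = np.array([0.8, 0.5, 0.7, 0.6])  # V(s_t), incl. the bootstrap value V(s_T)
rho_bar = np.minimum(1.0, np.array([1.3, 0.7, 2.0]))  # truncated importance weights

qret, qrets = v[-1], []
for i in reversed(range(len(rewards))):
    qret = rewards[i] + gamma * qret * (1.0 - dones[i])
    qrets.append(qret)
    qret = rho_bar[i] * (qret - q_i[i]) + v[i]
print(qrets[::-1])  # per-step Q_retrace targets, oldest step first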
# For ACER with PPO clipping instead of trust region
# def clip(ratio, eps_clip):
# # assume 0 <= eps_clip <= 1
# return tf.minimum(1 + eps_clip, tf.maximum(1 - eps_clip, ratio))
class Model(object):
def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, ent_coef, q_coef, gamma, max_grad_norm, lr,
rprop_alpha, rprop_epsilon, total_timesteps, lrschedule,
c, trust_region, alpha, delta):
sess = get_session()
nact = ac_space.n
nbatch = nenvs * nsteps
A = tf.placeholder(tf.int32, [nbatch]) # actions
D = tf.placeholder(tf.float32, [nbatch]) # dones
R = tf.placeholder(tf.float32, [nbatch]) # rewards, not returns
MU = tf.placeholder(tf.float32, [nbatch, nact]) # mu's
LR = tf.placeholder(tf.float32, [])
eps = 1e-6
step_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs,) + ob_space.shape)
train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape)
with tf.variable_scope('acer_model', reuse=tf.AUTO_REUSE):
step_model = policy(nbatch=nenvs, nsteps=1, observ_placeholder=step_ob_placeholder, sess=sess)
train_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)
params = find_trainable_variables("acer_model")
print("Params {}".format(len(params)))
for var in params:
print(var)
# create polyak averaged model
ema = tf.train.ExponentialMovingAverage(alpha)
ema_apply_op = ema.apply(params)
def custom_getter(getter, *args, **kwargs):
v = ema.average(getter(*args, **kwargs))
print(v.name)
return v
with tf.variable_scope("acer_model", custom_getter=custom_getter, reuse=True):
polyak_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)
# Notation: (var) = batch variable, (var)s = sequence variable, (var)_i = variable indexed by action at step i
# action probability distributions according to train_model, polyak_model and step_model
# policy.pi gives the probability distribution parameters; to obtain a distribution that sums to 1, take the softmax
train_model_p = tf.nn.softmax(train_model.pi)
polyak_model_p = tf.nn.softmax(polyak_model.pi)
step_model_p = tf.nn.softmax(step_model.pi)
v = tf.reduce_sum(train_model_p * train_model.q, axis = -1) # shape is [nenvs * (nsteps + 1)]
# strip off last step
f, f_pol, q = map(lambda var: strip(var, nenvs, nsteps), [train_model_p, polyak_model_p, train_model.q])
# Get pi and q values for actions taken
f_i = get_by_index(f, A)
q_i = get_by_index(q, A)
# Compute ratios for importance truncation
rho = f / (MU + eps)
rho_i = get_by_index(rho, A)
# Calculate Q_retrace targets
qret = q_retrace(R, D, q_i, v, rho_i, nenvs, nsteps, gamma)
# Calculate losses
# Entropy
# entropy = tf.reduce_mean(strip(train_model.pd.entropy(), nenvs, nsteps))
entropy = tf.reduce_mean(cat_entropy_softmax(f))
# Policy gradient loss, with truncated importance sampling & bias correction
v = strip(v, nenvs, nsteps, True)
check_shape([qret, v, rho_i, f_i], [[nenvs * nsteps]] * 4)
check_shape([rho, f, q], [[nenvs * nsteps, nact]] * 2)
# Truncated importance sampling
adv = qret - v
logf = tf.log(f_i + eps)
gain_f = logf * tf.stop_gradient(adv * tf.minimum(c, rho_i)) # [nenvs * nsteps]
loss_f = -tf.reduce_mean(gain_f)
# Bias correction for the truncation
adv_bc = (q - tf.reshape(v, [nenvs * nsteps, 1])) # [nenvs * nsteps, nact]
logf_bc = tf.log(f + eps) # / (f_old + eps)
check_shape([adv_bc, logf_bc], [[nenvs * nsteps, nact]]*2)
gain_bc = tf.reduce_sum(logf_bc * tf.stop_gradient(adv_bc * tf.nn.relu(1.0 - (c / (rho + eps))) * f), axis = 1) #IMP: This is sum, as expectation wrt f
loss_bc= -tf.reduce_mean(gain_bc)
loss_policy = loss_f + loss_bc
# Value/Q function loss, and explained variance
check_shape([qret, q_i], [[nenvs * nsteps]]*2)
ev = q_explained_variance(tf.reshape(q_i, [nenvs, nsteps]), tf.reshape(qret, [nenvs, nsteps]))
loss_q = tf.reduce_mean(tf.square(tf.stop_gradient(qret) - q_i)*0.5)
# Net loss
check_shape([loss_policy, loss_q, entropy], [[]] * 3)
loss = loss_policy + q_coef * loss_q - ent_coef * entropy
if trust_region:
g = tf.gradients(- (loss_policy - ent_coef * entropy) * nsteps * nenvs, f) #[nenvs * nsteps, nact]
# k = tf.gradients(KL(f_pol || f), f)
k = - f_pol / (f + eps) #[nenvs * nsteps, nact] # Directly computed gradient of KL divergence wrt f
k_dot_g = tf.reduce_sum(k * g, axis=-1)
adj = tf.maximum(0.0, (tf.reduce_sum(k * g, axis=-1) - delta) / (tf.reduce_sum(tf.square(k), axis=-1) + eps)) #[nenvs * nsteps]
# Calculate stats (before doing adjustment) for logging.
avg_norm_k = avg_norm(k)
avg_norm_g = avg_norm(g)
avg_norm_k_dot_g = tf.reduce_mean(tf.abs(k_dot_g))
avg_norm_adj = tf.reduce_mean(tf.abs(adj))
g = g - tf.reshape(adj, [nenvs * nsteps, 1]) * k
grads_f = -g/(nenvs*nsteps) # These are trust-region-adjusted gradients wrt f, i.e. the statistics of policy pi
grads_policy = tf.gradients(f, params, grads_f)
grads_q = tf.gradients(loss_q * q_coef, params)
grads = [gradient_add(g1, g2, param) for (g1, g2, param) in zip(grads_policy, grads_q, params)]
avg_norm_grads_f = avg_norm(grads_f) * (nsteps * nenvs)
norm_grads_q = tf.global_norm(grads_q)
norm_grads_policy = tf.global_norm(grads_policy)
else:
grads = tf.gradients(loss, params)
if max_grad_norm is not None:
grads, norm_grads = tf.clip_by_global_norm(grads, max_grad_norm)
grads = list(zip(grads, params))
trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=rprop_alpha, epsilon=rprop_epsilon)
_opt_op = trainer.apply_gradients(grads)
# so when you call _train, you first do the gradient step, then you apply ema
with tf.control_dependencies([_opt_op]):
_train = tf.group(ema_apply_op)
lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
# Ops/Summaries to run, and their names for logging
run_ops = [_train, loss, loss_q, entropy, loss_policy, loss_f, loss_bc, ev, norm_grads]
names_ops = ['loss', 'loss_q', 'entropy', 'loss_policy', 'loss_f', 'loss_bc', 'explained_variance',
'norm_grads']
if trust_region:
run_ops = run_ops + [norm_grads_q, norm_grads_policy, avg_norm_grads_f, avg_norm_k, avg_norm_g, avg_norm_k_dot_g,
avg_norm_adj]
names_ops = names_ops + ['norm_grads_q', 'norm_grads_policy', 'avg_norm_grads_f', 'avg_norm_k', 'avg_norm_g',
'avg_norm_k_dot_g', 'avg_norm_adj']
def train(obs, actions, rewards, dones, mus, states, masks, steps):
cur_lr = lr.value_steps(steps)
td_map = {train_model.X: obs, polyak_model.X: obs, A: actions, R: rewards, D: dones, MU: mus, LR: cur_lr}
if states is not None:
td_map[train_model.S] = states
td_map[train_model.M] = masks
td_map[polyak_model.S] = states
td_map[polyak_model.M] = masks
return names_ops, sess.run(run_ops, td_map)[1:] # strip off _train
def _step(observation, **kwargs):
return step_model._evaluate([step_model.action, step_model_p, step_model.state], observation, **kwargs)
self.train = train
self.save = functools.partial(save_variables, sess=sess, variables=params)
self.train_model = train_model
self.step_model = step_model
self._step = _step
self.step = self.step_model.step
self.initial_state = step_model.initial_state
tf.global_variables_initializer().run(session=sess)
class Acer():
def __init__(self, runner, model, buffer, log_interval):
self.runner = runner
self.model = model
self.buffer = buffer
self.log_interval = log_interval
self.tstart = None
self.episode_stats = EpisodeStats(runner.nsteps, runner.nenv)
self.steps = None
def call(self, on_policy):
runner, model, buffer, steps = self.runner, self.model, self.buffer, self.steps
if on_policy:
enc_obs, obs, actions, rewards, mus, dones, masks = runner.run()
self.episode_stats.feed(rewards, dones)
if buffer is not None:
buffer.put(enc_obs, actions, rewards, mus, dones, masks)
else:
# get obs, actions, rewards, mus, dones from buffer.
obs, actions, rewards, mus, dones, masks = buffer.get()
# reshape stuff correctly
obs = obs.reshape(runner.batch_ob_shape)
actions = actions.reshape([runner.nbatch])
rewards = rewards.reshape([runner.nbatch])
mus = mus.reshape([runner.nbatch, runner.nact])
dones = dones.reshape([runner.nbatch])
masks = masks.reshape([runner.batch_ob_shape[0]])
names_ops, values_ops = model.train(obs, actions, rewards, dones, mus, model.initial_state, masks, steps)
if on_policy and (int(steps/runner.nbatch) % self.log_interval == 0):
logger.record_tabular("total_timesteps", steps)
logger.record_tabular("fps", int(steps/(time.time() - self.tstart)))
# IMP: In EpisodicLife env, during training, we get done=True at each loss of life, not just at the terminal state.
# Thus, this is mean until end of life, not end of episode.
# For true episode rewards, see the monitor files in the log folder.
logger.record_tabular("mean_episode_length", self.episode_stats.mean_length())
logger.record_tabular("mean_episode_reward", self.episode_stats.mean_reward())
for name, val in zip(names_ops, values_ops):
logger.record_tabular(name, float(val))
logger.dump_tabular()
def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
max_grad_norm=10, lr=7e-4, lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99,
log_interval=100, buffer_size=50000, replay_ratio=4, replay_start=10000, c=10.0,
trust_region=True, alpha=0.99, delta=1, load_path=None, **network_kwargs):
'''
Main entrypoint for ACER (Actor-Critic with Experience Replay) algorithm (https://arxiv.org/pdf/1611.01224.pdf)
Train an agent with given network architecture on a given environment using ACER.
Parameters:
----------
network: policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
specifying the standard network architecture, or a function that takes tensorflow tensor as input and returns
tuple (output_tensor, extra_feed) where output tensor is the last network layer output, extra_feed is None for feed-forward
neural nets, and extra_feed is a dictionary describing how to feed state into the network for recurrent neural nets.
See baselines.common/policies.py/lstm for more details on using recurrent nets in policies
env: environment. Needs to be vectorized for parallel environment simulation.
The environments produced by gym.make can be wrapped using baselines.common.vec_env.DummyVecEnv class.
nsteps: int, number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
nenv is number of environment copies simulated in parallel) (default: 20)
nstack: int, size of the frame stack, i.e. number of the frames passed to the step model. Frames are stacked along channel dimension
(last image dimension) (default: 4)
total_timesteps: int, number of timesteps (i.e. number of actions taken in the environment) (default: 80M)
q_coef: float, value function loss coefficient in the optimization objective (analog of vf_coef for other actor-critic methods)
ent_coef: float, policy entropy coefficient in the optimization objective (default: 0.01)
max_grad_norm: float, gradient norm clipping coefficient. If set to None, no clipping. (default: 10),
lr: float, learning rate for RMSProp (current implementation has RMSProp hardcoded in) (default: 7e-4)
lrschedule: schedule of learning rate. Can be 'linear', 'constant', or a function [0..1] -> [0..1] that takes fraction of the training progress as input and
returns fraction of the learning rate (specified as lr) as output
rprop_epsilon: float, RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update) (default: 1e-5)
rprop_alpha: float, RMSProp decay parameter (default: 0.99)
gamma: float, reward discounting factor (default: 0.99)
log_interval: int, number of updates between logging events (default: 100)
buffer_size: int, size of the replay buffer (default: 50k)
replay_ratio: int, how many (on average) batches of data to sample from the replay buffer after each batch collected from the environment (default: 4)
replay_start: int, sampling from the replay buffer does not start until the buffer has at least this many samples (default: 10k)
c: float, importance weight clipping factor (default: 10)
trust_region: bool, whether or not the algorithm estimates the gradient of the KL divergence between the old and updated policy and uses it to determine the step size (default: True)
delta: float, max KL divergence between the old policy and updated policy (default: 1)
alpha: float, momentum factor in the Polyak (exponential moving average) averaging of the model parameters (default: 0.99)
load_path: str, path to load the model from (default: None)
**network_kwargs: keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
For instance, 'mlp' network architecture has arguments num_hidden and num_layers.
'''
print("Running Acer Simple")
print(locals())
set_global_seeds(seed)
if not isinstance(env, VecFrameStack):
env = VecFrameStack(env, 1)
policy = build_policy(env, network, estimate_q=True, **network_kwargs)
nenvs = env.num_envs
ob_space = env.observation_space
ac_space = env.action_space
nstack = env.nstack
model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps,
ent_coef=ent_coef, q_coef=q_coef, gamma=gamma,
max_grad_norm=max_grad_norm, lr=lr, rprop_alpha=rprop_alpha, rprop_epsilon=rprop_epsilon,
total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
trust_region=trust_region, alpha=alpha, delta=delta)
runner = Runner(env=env, model=model, nsteps=nsteps)
if replay_ratio > 0:
buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)
else:
buffer = None
nbatch = nenvs*nsteps
acer = Acer(runner, model, buffer, log_interval)
acer.tstart = time.time()
for acer.steps in range(0, total_timesteps, nbatch): #nbatch samples, 1 on_policy call and multiple off-policy calls
acer.call(on_policy=True)
if replay_ratio > 0 and buffer.has_atleast(replay_start):
n = np.random.poisson(replay_ratio)
for _ in range(n):
acer.call(on_policy=False) # no simulation steps in this
return model

View File

@@ -1,156 +0,0 @@
import numpy as np
class Buffer(object):
# gets obs, actions, rewards, mu's, (states, masks), dones
def __init__(self, env, nsteps, size=50000):
self.nenv = env.num_envs
self.nsteps = nsteps
# self.nh, self.nw, self.nc = env.observation_space.shape
self.obs_shape = env.observation_space.shape
self.obs_dtype = env.observation_space.dtype
self.ac_dtype = env.action_space.dtype
self.nc = self.obs_shape[-1]
self.nstack = env.nstack
self.nc //= self.nstack
self.nbatch = self.nenv * self.nsteps
self.size = size // (self.nsteps) # Each loc contains nenv * nsteps frames, thus total buffer is nenv * size frames
# Memory
self.enc_obs = None
self.actions = None
self.rewards = None
self.mus = None
self.dones = None
self.masks = None
# Size indexes
self.next_idx = 0
self.num_in_buffer = 0
def has_atleast(self, frames):
# Frames per env, so total (nenv * frames) Frames needed
# Each buffer loc has nenv * nsteps frames
return self.num_in_buffer >= (frames // self.nsteps)
def can_sample(self):
return self.num_in_buffer > 0
# Generate stacked frames
def decode(self, enc_obs, dones):
# enc_obs has shape [nenvs, nsteps + nstack, nh, nw, nc]
# dones has shape [nenvs, nsteps]
# returns stacked obs of shape [nenv, (nsteps + 1), nh, nw, nstack*nc]
return _stack_obs(enc_obs, dones,
nsteps=self.nsteps)
def put(self, enc_obs, actions, rewards, mus, dones, masks):
# enc_obs [nenv, (nsteps + nstack), nh, nw, nc]
# actions, rewards, dones [nenv, nsteps]
# mus [nenv, nsteps, nact]
if self.enc_obs is None:
self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=self.obs_dtype)
self.actions = np.empty([self.size] + list(actions.shape), dtype=self.ac_dtype)
self.rewards = np.empty([self.size] + list(rewards.shape), dtype=np.float32)
self.mus = np.empty([self.size] + list(mus.shape), dtype=np.float32)
self.dones = np.empty([self.size] + list(dones.shape), dtype=np.bool)
self.masks = np.empty([self.size] + list(masks.shape), dtype=np.bool)
self.enc_obs[self.next_idx] = enc_obs
self.actions[self.next_idx] = actions
self.rewards[self.next_idx] = rewards
self.mus[self.next_idx] = mus
self.dones[self.next_idx] = dones
self.masks[self.next_idx] = masks
self.next_idx = (self.next_idx + 1) % self.size
self.num_in_buffer = min(self.size, self.num_in_buffer + 1)
def take(self, x, idx, envx):
nenv = self.nenv
out = np.empty([nenv] + list(x.shape[2:]), dtype=x.dtype)
for i in range(nenv):
out[i] = x[idx[i], envx[i]]
return out
def get(self):
# returns
# obs [nenv, (nsteps + 1), nh, nw, nstack*nc]
# actions, rewards, dones [nenv, nsteps]
# mus [nenv, nsteps, nact]
nenv = self.nenv
assert self.can_sample()
# Sample exactly one id per env. If you sample across envs, then higher correlation in samples from same env.
idx = np.random.randint(0, self.num_in_buffer, nenv)
envx = np.arange(nenv)
take = lambda x: self.take(x, idx, envx) # for i in range(nenv)], axis = 0)
dones = take(self.dones)
enc_obs = take(self.enc_obs)
obs = self.decode(enc_obs, dones)
actions = take(self.actions)
rewards = take(self.rewards)
mus = take(self.mus)
masks = take(self.masks)
return obs, actions, rewards, mus, dones, masks
def _stack_obs_ref(enc_obs, dones, nsteps):
nenv = enc_obs.shape[0]
nstack = enc_obs.shape[1] - nsteps
nh, nw, nc = enc_obs.shape[2:]
obs_dtype = enc_obs.dtype
obs_shape = (nh, nw, nc*nstack)
mask = np.empty([nsteps + nstack - 1, nenv, 1, 1, 1], dtype=np.float32)
obs = np.zeros([nstack, nsteps + nstack, nenv, nh, nw, nc], dtype=obs_dtype)
x = np.reshape(enc_obs, [nenv, nsteps + nstack, nh, nw, nc]).swapaxes(1, 0) # [nsteps + nstack, nenv, nh, nw, nc]
mask[nstack-1:] = np.reshape(1.0 - dones, [nenv, nsteps, 1, 1, 1]).swapaxes(1, 0) # keep
mask[:nstack-1] = 1.0
# y = np.reshape(1 - dones, [nenvs, nsteps, 1, 1, 1])
for i in range(nstack):
obs[-(i + 1), i:] = x
# obs[:,i:,:,:,-(i+1),:] = x
x = x[:-1] * mask
mask = mask[1:]
return np.reshape(obs[:, (nstack-1):].transpose((2, 1, 3, 4, 0, 5)), (nenv, (nsteps + 1)) + obs_shape)
def _stack_obs(enc_obs, dones, nsteps):
nenv = enc_obs.shape[0]
nstack = enc_obs.shape[1] - nsteps
nc = enc_obs.shape[-1]
obs_ = np.zeros((nenv, nsteps + 1) + enc_obs.shape[2:-1] + (enc_obs.shape[-1] * nstack, ), dtype=enc_obs.dtype)
mask = np.ones((nenv, nsteps+1), dtype=enc_obs.dtype)
mask[:, 1:] = 1.0 - dones
mask = mask.reshape(mask.shape + tuple(np.ones(len(enc_obs.shape)-2, dtype=np.uint8)))
for i in range(nstack-1, -1, -1):
obs_[..., i * nc : (i + 1) * nc] = enc_obs[:, i : i + nsteps + 1, :]
if i < nstack-1:
obs_[..., i * nc : (i + 1) * nc] *= mask
mask[:, 1:, ...] *= mask[:, :-1, ...]
return obs_
def test_stack_obs():
nstack = 7
nenv = 1
nsteps = 5
obs_shape = (2, 3, nstack)
enc_obs_shape = (nenv, nsteps + nstack) + obs_shape[:-1] + (1,)
enc_obs = np.random.random(enc_obs_shape)
dones = np.random.randint(low=0, high=2, size=(nenv, nsteps))
stacked_obs_ref = _stack_obs_ref(enc_obs, dones, nsteps=nsteps)
stacked_obs_test = _stack_obs(enc_obs, dones, nsteps=nsteps)
np.testing.assert_allclose(stacked_obs_ref, stacked_obs_test)

View File

@@ -1,4 +0,0 @@
def atari():
return dict(
lrschedule='constant'
)

View File

@@ -1,81 +0,0 @@
import numpy as np
import tensorflow as tf
from baselines.common.policies import nature_cnn
from baselines.a2c.utils import fc, batch_to_seq, seq_to_batch, lstm, sample
class AcerCnnPolicy(object):
def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False):
nbatch = nenv * nsteps
nh, nw, nc = ob_space.shape
ob_shape = (nbatch, nh, nw, nc * nstack)
nact = ac_space.n
X = tf.placeholder(tf.uint8, ob_shape) # obs
with tf.variable_scope("model", reuse=reuse):
h = nature_cnn(X)
pi_logits = fc(h, 'pi', nact, init_scale=0.01)
pi = tf.nn.softmax(pi_logits)
q = fc(h, 'q', nact)
a = sample(tf.nn.softmax(pi_logits)) # could change this to use self.pi instead
self.initial_state = [] # not stateful
self.X = X
self.pi = pi # actual policy params now
self.pi_logits = pi_logits
self.q = q
self.vf = q
def step(ob, *args, **kwargs):
# returns actions, mus, states
a0, pi0 = sess.run([a, pi], {X: ob})
return a0, pi0, [] # dummy state
def out(ob, *args, **kwargs):
pi0, q0 = sess.run([pi, q], {X: ob})
return pi0, q0
def act(ob, *args, **kwargs):
return sess.run(a, {X: ob})
self.step = step
self.out = out
self.act = act
class AcerLstmPolicy(object):
def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False, nlstm=256):
nbatch = nenv * nsteps
nh, nw, nc = ob_space.shape
ob_shape = (nbatch, nh, nw, nc * nstack)
nact = ac_space.n
X = tf.placeholder(tf.uint8, ob_shape) # obs
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
with tf.variable_scope("model", reuse=reuse):
h = nature_cnn(X)
# lstm
xs = batch_to_seq(h, nenv, nsteps)
ms = batch_to_seq(M, nenv, nsteps)
h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm)
h5 = seq_to_batch(h5)
pi_logits = fc(h5, 'pi', nact, init_scale=0.01)
pi = tf.nn.softmax(pi_logits)
q = fc(h5, 'q', nact)
a = sample(pi_logits) # could change this to use self.pi instead
self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32)
self.X = X
self.M = M
self.S = S
self.pi = pi # actual policy params now
self.q = q
def step(ob, state, mask, *args, **kwargs):
# returns actions, mus, states
a0, pi0, s = sess.run([a, pi, snew], {X: ob, S: state, M: mask})
return a0, pi0, s
self.step = step

View File

@@ -1,61 +0,0 @@
import numpy as np
from baselines.common.runners import AbstractEnvRunner
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from gym import spaces
class Runner(AbstractEnvRunner):
def __init__(self, env, model, nsteps):
super().__init__(env=env, model=model, nsteps=nsteps)
assert isinstance(env.action_space, spaces.Discrete), 'This ACER implementation works only with discrete action spaces!'
assert isinstance(env, VecFrameStack)
self.nact = env.action_space.n
nenv = self.nenv
self.nbatch = nenv * nsteps
self.batch_ob_shape = (nenv*(nsteps+1),) + env.observation_space.shape
self.obs = env.reset()
self.obs_dtype = env.observation_space.dtype
self.ac_dtype = env.action_space.dtype
self.nstack = self.env.nstack
self.nc = self.batch_ob_shape[-1] // self.nstack
def run(self):
# enc_obs = np.split(self.obs, self.nstack, axis=3) # so now list of obs steps
enc_obs = np.split(self.env.stackedobs, self.env.nstack, axis=-1)
mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards = [], [], [], [], []
for _ in range(self.nsteps):
actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones)
mb_obs.append(np.copy(self.obs))
mb_actions.append(actions)
mb_mus.append(mus)
mb_dones.append(self.dones)
obs, rewards, dones, _ = self.env.step(actions)
# states information for stateful models like LSTM
self.states = states
self.dones = dones
self.obs = obs
mb_rewards.append(rewards)
enc_obs.append(obs[..., -self.nc:])
mb_obs.append(np.copy(self.obs))
mb_dones.append(self.dones)
enc_obs = np.asarray(enc_obs, dtype=self.obs_dtype).swapaxes(1, 0)
mb_obs = np.asarray(mb_obs, dtype=self.obs_dtype).swapaxes(1, 0)
mb_actions = np.asarray(mb_actions, dtype=self.ac_dtype).swapaxes(1, 0)
mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)
mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0)
mb_masks = mb_dones # Used for stateful models like LSTMs to mask state when done
mb_dones = mb_dones[:, 1:] # Used for calculating returns. The dones array is now aligned with rewards
# shapes are now [nenv, nsteps, []]
# When pulling from buffer, arrays will now be reshaped in place, preventing a deep copy.
return enc_obs, mb_obs, mb_actions, mb_rewards, mb_mus, mb_dones, mb_masks

View File

@@ -1,9 +0,0 @@
# ACKTR
- Original paper: https://arxiv.org/abs/1708.05144
- Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/
- `python -m baselines.run --alg=acktr --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options.
- also refer to the repo-wide [README.md](../../README.md#training-models)
## ACKTR with continuous action spaces
The code of ACKTR has been refactored to handle both discrete and continuous action spaces uniformly. In the original version, discrete and continuous action spaces were handled by different code (acktr_disc.py and acktr_cont.py) with little overlap. If you are interested in the original version of ACKTR for continuous action spaces, use the `old_acktr_cont` branch. Note that the original code performs better on the MuJoCo tasks than the refactored version; we are still investigating why.

View File

@@ -1,158 +0,0 @@
import os.path as osp
import time
import functools
import tensorflow as tf
from baselines import logger
from baselines.common import set_global_seeds, explained_variance
from baselines.common.policies import build_policy
from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.a2c.runner import Runner
from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.acktr import kfac
from baselines.ppo2.ppo2 import safemean
from collections import deque
class Model(object):
def __init__(self, policy, ob_space, ac_space, nenvs,total_timesteps, nprocs=32, nsteps=20,
ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
kfac_clip=0.001, lrschedule='linear', is_async=True):
self.sess = sess = get_session()
nbatch = nenvs * nsteps
with tf.variable_scope('acktr_model', reuse=tf.AUTO_REUSE):
self.model = step_model = policy(nenvs, 1, sess=sess)
self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess)
A = train_model.pdtype.sample_placeholder([None])
ADV = tf.placeholder(tf.float32, [nbatch])
R = tf.placeholder(tf.float32, [nbatch])
PG_LR = tf.placeholder(tf.float32, [])
VF_LR = tf.placeholder(tf.float32, [])
neglogpac = train_model.pd.neglogp(A)
self.logits = train_model.pi
##training loss
pg_loss = tf.reduce_mean(ADV*neglogpac)
entropy = tf.reduce_mean(train_model.pd.entropy())
pg_loss = pg_loss - ent_coef * entropy
vf_loss = tf.losses.mean_squared_error(tf.squeeze(train_model.vf), R)
train_loss = pg_loss + vf_coef * vf_loss
##Fisher loss construction
self.pg_fisher = pg_fisher_loss = -tf.reduce_mean(neglogpac)
sample_net = train_model.vf + tf.random_normal(tf.shape(train_model.vf))
self.vf_fisher = vf_fisher_loss = - vf_fisher_coef*tf.reduce_mean(tf.pow(train_model.vf - tf.stop_gradient(sample_net), 2))
self.joint_fisher = joint_fisher_loss = pg_fisher_loss + vf_fisher_loss
self.params=params = find_trainable_variables("acktr_model")
self.grads_check = grads = tf.gradients(train_loss,params)
with tf.device('/gpu:0'):
self.optim = optim = kfac.KfacOptimizer(learning_rate=PG_LR, clip_kl=kfac_clip,\
momentum=0.9, kfac_update=1, epsilon=0.01,\
stats_decay=0.99, is_async=is_async, cold_iter=10, max_grad_norm=max_grad_norm)
# update_stats_op = optim.compute_and_apply_stats(joint_fisher_loss, var_list=params)
optim.compute_and_apply_stats(joint_fisher_loss, var_list=params)
train_op, q_runner = optim.apply_gradients(list(zip(grads,params)))
self.q_runner = q_runner
self.lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
def train(obs, states, rewards, masks, actions, values):
advs = rewards - values
for step in range(len(obs)):
cur_lr = self.lr.value()
td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, PG_LR:cur_lr, VF_LR:cur_lr}
if states is not None:
td_map[train_model.S] = states
td_map[train_model.M] = masks
policy_loss, value_loss, policy_entropy, _ = sess.run(
[pg_loss, vf_loss, entropy, train_op],
td_map
)
return policy_loss, value_loss, policy_entropy
self.train = train
self.save = functools.partial(save_variables, sess=sess)
self.load = functools.partial(load_variables, sess=sess)
self.train_model = train_model
self.step_model = step_model
self.step = step_model.step
self.value = step_model.value
self.initial_state = step_model.initial_state
tf.global_variables_initializer().run(session=sess)
def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=100, nprocs=32, nsteps=20,
ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs):
set_global_seeds(seed)
if network == 'cnn':
network_kwargs['one_dim_bias'] = True
policy = build_policy(env, network, **network_kwargs)
nenvs = env.num_envs
ob_space = env.observation_space
ac_space = env.action_space
make_model = lambda : Model(policy, ob_space, ac_space, nenvs, total_timesteps, nprocs=nprocs, nsteps
=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, vf_fisher_coef=
vf_fisher_coef, lr=lr, max_grad_norm=max_grad_norm, kfac_clip=kfac_clip,
lrschedule=lrschedule, is_async=is_async)
if save_interval and logger.get_dir():
import cloudpickle
with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
fh.write(cloudpickle.dumps(make_model))
model = make_model()
if load_path is not None:
model.load(load_path)
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
epinfobuf = deque(maxlen=100)
nbatch = nenvs*nsteps
tstart = time.time()
coord = tf.train.Coordinator()
if is_async:
enqueue_threads = model.q_runner.create_threads(model.sess, coord=coord, start=True)
else:
enqueue_threads = []
for update in range(1, total_timesteps//nbatch+1):
obs, states, rewards, masks, actions, values, epinfos = runner.run()
epinfobuf.extend(epinfos)
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
model.old_obs = obs
nseconds = time.time()-tstart
fps = int((update*nbatch)/nseconds)
if update % log_interval == 0 or update == 1:
ev = explained_variance(values, rewards)
logger.record_tabular("nupdates", update)
logger.record_tabular("total_timesteps", update*nbatch)
logger.record_tabular("fps", fps)
logger.record_tabular("policy_entropy", float(policy_entropy))
logger.record_tabular("policy_loss", float(policy_loss))
logger.record_tabular("value_loss", float(value_loss))
logger.record_tabular("explained_variance", float(ev))
logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
logger.dump_tabular()
if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
savepath = osp.join(logger.get_dir(), 'checkpoint%.5i'%update)
print('Saving to', savepath)
model.save(savepath)
coord.request_stop()
coord.join(enqueue_threads)
return model

View File

@@ -1,5 +0,0 @@
def mujoco():
return dict(
nsteps=2500,
value_network='copy'
)

View File

@@ -1,928 +0,0 @@
import tensorflow as tf
import numpy as np
import re
# flake8: noqa F403, F405
from baselines.acktr.kfac_utils import *
from functools import reduce
KFAC_OPS = ['MatMul', 'Conv2D', 'BiasAdd']
KFAC_DEBUG = False
class KfacOptimizer():
# note that KfacOptimizer will be truly synchronous (and thus deterministic) only if a single-threaded session is used
def __init__(self, learning_rate=0.01, momentum=0.9, clip_kl=0.01, kfac_update=2, stats_accum_iter=60, full_stats_init=False, cold_iter=100, cold_lr=None, is_async=False, async_stats=False, epsilon=1e-2, stats_decay=0.95, blockdiag_bias=False, channel_fac=False, factored_damping=False, approxT2=False, use_float64=False, weight_decay_dict={},max_grad_norm=0.5):
self.max_grad_norm = max_grad_norm
self._lr = learning_rate
self._momentum = momentum
self._clip_kl = clip_kl
self._channel_fac = channel_fac
self._kfac_update = kfac_update
self._async = is_async
self._async_stats = async_stats
self._epsilon = epsilon
self._stats_decay = stats_decay
self._blockdiag_bias = blockdiag_bias
self._approxT2 = approxT2
self._use_float64 = use_float64
self._factored_damping = factored_damping
self._cold_iter = cold_iter
if cold_lr == None:
# good heuristics
self._cold_lr = self._lr# * 3.
else:
self._cold_lr = cold_lr
self._stats_accum_iter = stats_accum_iter
self._weight_decay_dict = weight_decay_dict
self._diag_init_coeff = 0.
self._full_stats_init = full_stats_init
if not self._full_stats_init:
self._stats_accum_iter = self._cold_iter
self.sgd_step = tf.Variable(0, name='KFAC/sgd_step', trainable=False)
self.global_step = tf.Variable(
0, name='KFAC/global_step', trainable=False)
self.cold_step = tf.Variable(0, name='KFAC/cold_step', trainable=False)
self.factor_step = tf.Variable(
0, name='KFAC/factor_step', trainable=False)
self.stats_step = tf.Variable(
0, name='KFAC/stats_step', trainable=False)
self.vFv = tf.Variable(0., name='KFAC/vFv', trainable=False)
self.factors = {}
self.param_vars = []
self.stats = {}
self.stats_eigen = {}
def getFactors(self, g, varlist):
graph = tf.get_default_graph()
factorTensors = {}
fpropTensors = []
bpropTensors = []
opTypes = []
fops = []
def searchFactors(gradient, graph):
# hard-coded search strategy
bpropOp = gradient.op
bpropOp_name = bpropOp.name
bTensors = []
fTensors = []
# combining additive gradients, assume they are the same op type and
# independent
if 'AddN' in bpropOp_name:
factors = []
for g in gradient.op.inputs:
factors.append(searchFactors(g, graph))
op_names = [item['opName'] for item in factors]
# TO-DO: need to check all the attribute of the ops as well
print (gradient.name)
print (op_names)
print (len(np.unique(op_names)))
assert len(np.unique(op_names)) == 1, gradient.name + \
' is shared among different computation OPs'
bTensors = reduce(lambda x, y: x + y,
[item['bpropFactors'] for item in factors])
if len(factors[0]['fpropFactors']) > 0:
fTensors = reduce(
lambda x, y: x + y, [item['fpropFactors'] for item in factors])
fpropOp_name = op_names[0]
fpropOp = factors[0]['op']
else:
fpropOp_name = re.search(
'gradientsSampled(_[0-9]+|)/(.+?)_grad', bpropOp_name).group(2)
fpropOp = graph.get_operation_by_name(fpropOp_name)
if fpropOp.op_def.name in KFAC_OPS:
# Known OPs
###
bTensor = [
i for i in bpropOp.inputs if 'gradientsSampled' in i.name][-1]
bTensorShape = fpropOp.outputs[0].get_shape()
if bTensor.get_shape()[0].value == None:
bTensor.set_shape(bTensorShape)
bTensors.append(bTensor)
###
if fpropOp.op_def.name == 'BiasAdd':
fTensors = []
else:
fTensors.append(
[i for i in fpropOp.inputs if param.op.name not in i.name][0])
fpropOp_name = fpropOp.op_def.name
else:
# unknown OPs, block approximation used
bInputsList = [i for i in bpropOp.inputs[
0].op.inputs if 'gradientsSampled' in i.name if 'Shape' not in i.name]
if len(bInputsList) > 0:
bTensor = bInputsList[0]
bTensorShape = fpropOp.outputs[0].get_shape()
if len(bTensor.get_shape()) > 0 and bTensor.get_shape()[0].value == None:
bTensor.set_shape(bTensorShape)
bTensors.append(bTensor)
fpropOp_name = opTypes.append('UNK-' + fpropOp.op_def.name)
return {'opName': fpropOp_name, 'op': fpropOp, 'fpropFactors': fTensors, 'bpropFactors': bTensors}
for t, param in zip(g, varlist):
if KFAC_DEBUG:
print(('get factor for '+param.name))
factors = searchFactors(t, graph)
factorTensors[param] = factors
########
# check associated weights and biases for the homogeneous coordinate representation
# and check for redundant factors
# TO-DO: there may be a bug in detecting the associated bias and weights for
# forking layers, e.g. in inception models.
for param in varlist:
factorTensors[param]['assnWeights'] = None
factorTensors[param]['assnBias'] = None
for param in varlist:
if factorTensors[param]['opName'] == 'BiasAdd':
factorTensors[param]['assnWeights'] = None
for item in varlist:
if len(factorTensors[item]['bpropFactors']) > 0:
if (set(factorTensors[item]['bpropFactors']) == set(factorTensors[param]['bpropFactors'])) and (len(factorTensors[item]['fpropFactors']) > 0):
factorTensors[param]['assnWeights'] = item
factorTensors[item]['assnBias'] = param
factorTensors[param]['bpropFactors'] = factorTensors[
item]['bpropFactors']
########
########
# concatenate the additive gradients along the batch dimension, i.e.
# assuming independence structure
for key in ['fpropFactors', 'bpropFactors']:
for i, param in enumerate(varlist):
if len(factorTensors[param][key]) > 0:
if (key + '_concat') not in factorTensors[param]:
name_scope = factorTensors[param][key][0].name.split(':')[
0]
with tf.name_scope(name_scope):
factorTensors[param][
key + '_concat'] = tf.concat(factorTensors[param][key], 0)
else:
factorTensors[param][key + '_concat'] = None
for j, param2 in enumerate(varlist[(i + 1):]):
if (len(factorTensors[param][key]) > 0) and (set(factorTensors[param2][key]) == set(factorTensors[param][key])):
factorTensors[param2][key] = factorTensors[param][key]
factorTensors[param2][
key + '_concat'] = factorTensors[param][key + '_concat']
########
if KFAC_DEBUG:
for items in zip(varlist, fpropTensors, bpropTensors, opTypes):
print((items[0].name, factorTensors[items[0]]))
self.factors = factorTensors
return factorTensors
def getStats(self, factors, varlist):
if len(self.stats) == 0:
# initialize stats variables on CPU because eigen decomp is
# computed on CPU
with tf.device('/cpu'):
tmpStatsCache = {}
# search for tensor factors and
# use block diag approx for the bias units
for var in varlist:
fpropFactor = factors[var]['fpropFactors_concat']
bpropFactor = factors[var]['bpropFactors_concat']
opType = factors[var]['opName']
if opType == 'Conv2D':
Kh = var.get_shape()[0]
Kw = var.get_shape()[1]
C = fpropFactor.get_shape()[-1]
Oh = bpropFactor.get_shape()[1]
Ow = bpropFactor.get_shape()[2]
if Oh == 1 and Ow == 1 and self._channel_fac:
# factorization along the channels does not support
# homogeneous coordinates
var_assnBias = factors[var]['assnBias']
if var_assnBias:
factors[var]['assnBias'] = None
factors[var_assnBias]['assnWeights'] = None
##
for var in varlist:
fpropFactor = factors[var]['fpropFactors_concat']
bpropFactor = factors[var]['bpropFactors_concat']
opType = factors[var]['opName']
self.stats[var] = {'opName': opType,
'fprop_concat_stats': [],
'bprop_concat_stats': [],
'assnWeights': factors[var]['assnWeights'],
'assnBias': factors[var]['assnBias'],
}
if fpropFactor is not None:
if fpropFactor not in tmpStatsCache:
if opType == 'Conv2D':
Kh = var.get_shape()[0]
Kw = var.get_shape()[1]
C = fpropFactor.get_shape()[-1]
Oh = bpropFactor.get_shape()[1]
Ow = bpropFactor.get_shape()[2]
if Oh == 1 and Ow == 1 and self._channel_fac:
# factorization along the channels
# assume independence between input channels and spatial
# 2K-1 x 2K-1 covariance matrix and C x C covariance matrix
# factorization along the channels does not
# support homogeneous coordinates, so assnBias
# is always None
fpropFactor2_size = Kh * Kw
slot_fpropFactor_stats2 = tf.Variable(tf.diag(tf.ones(
[fpropFactor2_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
self.stats[var]['fprop_concat_stats'].append(
slot_fpropFactor_stats2)
fpropFactor_size = C
else:
# 2K-1 x 2K-1 x C x C covariance matrix
# assume BHWC
fpropFactor_size = Kh * Kw * C
else:
# D x D covariance matrix
fpropFactor_size = fpropFactor.get_shape()[-1]
# use homogeneous coordinate
if not self._blockdiag_bias and self.stats[var]['assnBias']:
fpropFactor_size += 1
slot_fpropFactor_stats = tf.Variable(tf.diag(tf.ones(
[fpropFactor_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
self.stats[var]['fprop_concat_stats'].append(
slot_fpropFactor_stats)
if opType != 'Conv2D':
tmpStatsCache[fpropFactor] = self.stats[
var]['fprop_concat_stats']
else:
self.stats[var][
'fprop_concat_stats'] = tmpStatsCache[fpropFactor]
if bpropFactor is not None:
# no need to collect backward stats for bias vectors if
# using homogeneous coordinates
if not((not self._blockdiag_bias) and self.stats[var]['assnWeights']):
if bpropFactor not in tmpStatsCache:
slot_bpropFactor_stats = tf.Variable(tf.diag(tf.ones([bpropFactor.get_shape(
)[-1]])) * self._diag_init_coeff, name='KFAC_STATS/' + bpropFactor.op.name, trainable=False)
self.stats[var]['bprop_concat_stats'].append(
slot_bpropFactor_stats)
tmpStatsCache[bpropFactor] = self.stats[
var]['bprop_concat_stats']
else:
self.stats[var][
'bprop_concat_stats'] = tmpStatsCache[bpropFactor]
return self.stats
def compute_and_apply_stats(self, loss_sampled, var_list=None):
varlist = var_list
if varlist is None:
varlist = tf.trainable_variables()
stats = self.compute_stats(loss_sampled, var_list=varlist)
return self.apply_stats(stats)
def compute_stats(self, loss_sampled, var_list=None):
varlist = var_list
if varlist is None:
varlist = tf.trainable_variables()
gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
self.gs = gs
factors = self.getFactors(gs, varlist)
stats = self.getStats(factors, varlist)
updateOps = []
statsUpdates = {}
statsUpdates_cache = {}
for var in varlist:
opType = factors[var]['opName']
fops = factors[var]['op']
fpropFactor = factors[var]['fpropFactors_concat']
fpropStats_vars = stats[var]['fprop_concat_stats']
bpropFactor = factors[var]['bpropFactors_concat']
bpropStats_vars = stats[var]['bprop_concat_stats']
SVD_factors = {}
for stats_var in fpropStats_vars:
stats_var_dim = int(stats_var.get_shape()[0])
if stats_var not in statsUpdates_cache:
old_fpropFactor = fpropFactor
B = (tf.shape(fpropFactor)[0]) # batch size
if opType == 'Conv2D':
strides = fops.get_attr("strides")
padding = fops.get_attr("padding")
convkernel_size = var.get_shape()[0:3]
KH = int(convkernel_size[0])
KW = int(convkernel_size[1])
C = int(convkernel_size[2])
flatten_size = int(KH * KW * C)
Oh = int(bpropFactor.get_shape()[1])
Ow = int(bpropFactor.get_shape()[2])
if Oh == 1 and Ow == 1 and self._channel_fac:
# factorization along the channels
# assume independence among input channels
# factor = B x 1 x 1 x (KH xKW x C)
# patches = B x Oh x Ow x (KH xKW x C)
if len(SVD_factors) == 0:
if KFAC_DEBUG:
print(('approx %s act factor with rank-1 SVD factors' % (var.name)))
# find closest rank-1 approx to the feature map
S, U, V = tf.batch_svd(tf.reshape(
fpropFactor, [-1, KH * KW, C]))
# get rank-1 approx slides
sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
patches_k = U[:, :, 0] * sqrtS1 # B x KH*KW
full_factor_shape = fpropFactor.get_shape()
patches_k.set_shape(
[full_factor_shape[0], KH * KW])
patches_c = V[:, :, 0] * sqrtS1 # B x C
patches_c.set_shape([full_factor_shape[0], C])
SVD_factors[C] = patches_c
SVD_factors[KH * KW] = patches_k
fpropFactor = SVD_factors[stats_var_dim]
else:
# memory-inefficient implementation
patches = tf.extract_image_patches(fpropFactor, ksizes=[1, convkernel_size[
0], convkernel_size[1], 1], strides=strides, rates=[1, 1, 1, 1], padding=padding)
if self._approxT2:
if KFAC_DEBUG:
print(('approxT2 act fisher for %s' % (var.name)))
# T^2 terms * 1/T^2, size: B x C
fpropFactor = tf.reduce_mean(patches, [1, 2])
else:
# size: (B x Oh x Ow) x C
fpropFactor = tf.reshape(
patches, [-1, flatten_size]) / Oh / Ow
fpropFactor_size = int(fpropFactor.get_shape()[-1])
if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
if opType == 'Conv2D' and not self._approxT2:
# correct padding for numerical stability (we
# divided out Oh x Ow from the activations for the T1 approx)
fpropFactor = tf.concat([fpropFactor, tf.ones(
[tf.shape(fpropFactor)[0], 1]) / Oh / Ow], 1)
else:
# use homogeneous coordinates
fpropFactor = tf.concat(
[fpropFactor, tf.ones([tf.shape(fpropFactor)[0], 1])], 1)
# average over the number of data points in a batch
# divided by B
cov = tf.matmul(fpropFactor, fpropFactor,
transpose_a=True) / tf.cast(B, tf.float32)
updateOps.append(cov)
statsUpdates[stats_var] = cov
if opType != 'Conv2D':
# HACK: for convolution we recompute fprop stats for
# every layer including forking layers
statsUpdates_cache[stats_var] = cov
for stats_var in bpropStats_vars:
stats_var_dim = int(stats_var.get_shape()[0])
if stats_var not in statsUpdates_cache:
old_bpropFactor = bpropFactor
bpropFactor_shape = bpropFactor.get_shape()
B = tf.shape(bpropFactor)[0] # batch size
C = int(bpropFactor_shape[-1]) # num channels
if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
if fpropFactor is not None:
if self._approxT2:
if KFAC_DEBUG:
print(('approxT2 grad fisher for %s' % (var.name)))
bpropFactor = tf.reduce_sum(
bpropFactor, [1, 2]) # T^2 terms * 1/T^2
else:
bpropFactor = tf.reshape(
bpropFactor, [-1, C]) * Oh * Ow # T * 1/T terms
else:
# just use the block-diagonal approximation; the spatially
# independent structure does not apply here, so sum over
# spatial locations
if KFAC_DEBUG:
print(('block diag approx fisher for %s' % (var.name)))
bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])
# assume the sampled loss is averaged. TO-DO: figure out a better
# way to handle this
bpropFactor *= tf.to_float(B)
##
cov_b = tf.matmul(
bpropFactor, bpropFactor, transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])
updateOps.append(cov_b)
statsUpdates[stats_var] = cov_b
statsUpdates_cache[stats_var] = cov_b
if KFAC_DEBUG:
aKey = list(statsUpdates.keys())[0]
statsUpdates[aKey] = tf.Print(statsUpdates[aKey],
[tf.convert_to_tensor('step:'),
self.global_step,
tf.convert_to_tensor(
'computing stats'),
])
self.statsUpdates = statsUpdates
return statsUpdates
def apply_stats(self, statsUpdates):
""" compute stats and update/apply the new stats to the running average
"""
def updateAccumStats():
if self._full_stats_init:
return tf.cond(tf.greater(self.sgd_step, self._cold_iter), lambda: tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter)), tf.no_op)
else:
return tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter))
def updateRunningAvgStats(statsUpdates, fac_iter=1):
# return tf.cond(tf.greater_equal(self.factor_step,
# tf.convert_to_tensor(fac_iter)), lambda:
# tf.group(*self._apply_stats(stats_list, varlist)), tf.no_op)
return tf.group(*self._apply_stats(statsUpdates))
if self._async_stats:
# asynchronous stats update
update_stats = self._apply_stats(statsUpdates)
queue = tf.FIFOQueue(1, [item.dtype for item in update_stats], shapes=[
item.get_shape() for item in update_stats])
enqueue_op = queue.enqueue(update_stats)
def dequeue_stats_op():
return queue.dequeue()
self.qr_stats = tf.train.QueueRunner(queue, [enqueue_op])
update_stats_op = tf.cond(tf.equal(queue.size(), tf.convert_to_tensor(
0)), tf.no_op, lambda: tf.group(*[dequeue_stats_op(), ]))
else:
# synchronous stats update
update_stats_op = tf.cond(tf.greater_equal(
self.stats_step, self._stats_accum_iter), lambda: updateRunningAvgStats(statsUpdates), updateAccumStats)
self._update_stats_op = update_stats_op
return update_stats_op
def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
updateOps = []
# obtain the stats var list
for stats_var in statsUpdates:
stats_new = statsUpdates[stats_var]
if accumulate:
# simple superbatch averaging
update_op = tf.assign_add(
stats_var, accumulateCoeff * stats_new, use_locking=True)
else:
# exponential running averaging
update_op = tf.assign(
stats_var, stats_var * self._stats_decay, use_locking=True)
update_op = tf.assign_add(
update_op, (1. - self._stats_decay) * stats_new, use_locking=True)
updateOps.append(update_op)
with tf.control_dependencies(updateOps):
stats_step_op = tf.assign_add(self.stats_step, 1)
if KFAC_DEBUG:
stats_step_op = (tf.Print(stats_step_op,
[tf.convert_to_tensor('step:'),
self.global_step,
tf.convert_to_tensor('fac step:'),
self.factor_step,
tf.convert_to_tensor('sgd step:'),
self.sgd_step,
tf.convert_to_tensor('Accum:'),
tf.convert_to_tensor(accumulate),
tf.convert_to_tensor('Accum coeff:'),
tf.convert_to_tensor(accumulateCoeff),
tf.convert_to_tensor('stat step:'),
self.stats_step, updateOps[0], updateOps[1]]))
return [stats_step_op, ]
def getStatsEigen(self, stats=None):
if len(self.stats_eigen) == 0:
stats_eigen = {}
if stats is None:
stats = self.stats
tmpEigenCache = {}
with tf.device('/cpu:0'):
for var in stats:
for key in ['fprop_concat_stats', 'bprop_concat_stats']:
for stats_var in stats[var][key]:
if stats_var not in tmpEigenCache:
stats_dim = stats_var.get_shape()[1].value
e = tf.Variable(tf.ones(
[stats_dim]), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/e', trainable=False)
Q = tf.Variable(tf.diag(tf.ones(
[stats_dim])), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/Q', trainable=False)
stats_eigen[stats_var] = {'e': e, 'Q': Q}
tmpEigenCache[
stats_var] = stats_eigen[stats_var]
else:
stats_eigen[stats_var] = tmpEigenCache[
stats_var]
self.stats_eigen = stats_eigen
return self.stats_eigen
def computeStatsEigen(self):
""" compute the eigen decomp using copied var stats to avoid concurrent read/write from other queue """
# TO-DO: figure out why this op has delays (possibly moving
# eigenvectors around?)
with tf.device('/cpu:0'):
def removeNone(tensor_list):
local_list = []
for item in tensor_list:
if item is not None:
local_list.append(item)
return local_list
def copyStats(var_list):
print("copying stats to buffer tensors before eigen decomp")
redundant_stats = {}
copied_list = []
for item in var_list:
if item is not None:
if item not in redundant_stats:
if self._use_float64:
redundant_stats[item] = tf.cast(
tf.identity(item), tf.float64)
else:
redundant_stats[item] = tf.identity(item)
copied_list.append(redundant_stats[item])
else:
copied_list.append(None)
return copied_list
#stats = [copyStats(self.fStats), copyStats(self.bStats)]
#stats = [self.fStats, self.bStats]
stats_eigen = self.stats_eigen
computedEigen = {}
eigen_reverse_lookup = {}
updateOps = []
# sync copied stats
# with tf.control_dependencies(removeNone(stats[0]) +
# removeNone(stats[1])):
with tf.control_dependencies([]):
for stats_var in stats_eigen:
if stats_var not in computedEigen:
eigens = tf.self_adjoint_eig(stats_var)
e = eigens[0]
Q = eigens[1]
if self._use_float64:
e = tf.cast(e, tf.float32)
Q = tf.cast(Q, tf.float32)
updateOps.append(e)
updateOps.append(Q)
computedEigen[stats_var] = {'e': e, 'Q': Q}
eigen_reverse_lookup[e] = stats_eigen[stats_var]['e']
eigen_reverse_lookup[Q] = stats_eigen[stats_var]['Q']
self.eigen_reverse_lookup = eigen_reverse_lookup
self.eigen_update_list = updateOps
if KFAC_DEBUG:
self.eigen_update_list = [item for item in updateOps]
with tf.control_dependencies(updateOps):
updateOps.append(tf.Print(tf.constant(
0.), [tf.convert_to_tensor('computed factor eigen')]))
return updateOps
def applyStatsEigen(self, eigen_list):
updateOps = []
print(('updating %d eigenvalue/vectors' % len(eigen_list)))
for i, (tensor, mark) in enumerate(zip(eigen_list, self.eigen_update_list)):
stats_eigen_var = self.eigen_reverse_lookup[mark]
updateOps.append(
tf.assign(stats_eigen_var, tensor, use_locking=True))
with tf.control_dependencies(updateOps):
factor_step_op = tf.assign_add(self.factor_step, 1)
updateOps.append(factor_step_op)
if KFAC_DEBUG:
updateOps.append(tf.Print(tf.constant(
0.), [tf.convert_to_tensor('updated kfac factors')]))
return updateOps
def getKfacPrecondUpdates(self, gradlist, varlist):
updatelist = []
vg = 0.
assert len(self.stats) > 0
assert len(self.stats_eigen) > 0
assert len(self.factors) > 0
counter = 0
grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}
for grad, var in zip(gradlist, varlist):
GRAD_RESHAPE = False
GRAD_TRANSPOSE = False
fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
bpropFactoredFishers = self.stats[var]['bprop_concat_stats']
if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
counter += 1
GRAD_SHAPE = grad.get_shape()
if len(grad.get_shape()) > 2:
# reshape conv kernel parameters
KW = int(grad.get_shape()[0])
KH = int(grad.get_shape()[1])
C = int(grad.get_shape()[2])
D = int(grad.get_shape()[3])
if len(fpropFactoredFishers) > 1 and self._channel_fac:
# reshape conv kernel parameters into tensor
grad = tf.reshape(grad, [KW * KH, C, D])
else:
# reshape conv kernel parameters into 2D grad
grad = tf.reshape(grad, [-1, D])
GRAD_RESHAPE = True
elif len(grad.get_shape()) == 1:
# reshape bias or 1D parameters
D = int(grad.get_shape()[0])
grad = tf.expand_dims(grad, 0)
GRAD_RESHAPE = True
else:
# 2D parameters
C = int(grad.get_shape()[0])
D = int(grad.get_shape()[1])
if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
# the homogeneous-coordinate trick only works for 2D grads.
# TO-DO: figure out how to factorize the bias grad
# stack bias grad
var_assnBias = self.stats[var]['assnBias']
grad = tf.concat(
[grad, tf.expand_dims(grad_dict[var_assnBias], 0)], 0)
# project gradient to eigen space and reshape the eigenvalues
# for broadcasting
eigVals = []
for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
Q = self.stats_eigen[stats]['Q']
e = detectMinVal(self.stats_eigen[stats][
'e'], var, name='act', debug=KFAC_DEBUG)
Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
eigVals.append(e)
grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)
for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
Q = self.stats_eigen[stats]['Q']
e = detectMinVal(self.stats_eigen[stats][
'e'], var, name='grad', debug=KFAC_DEBUG)
Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
eigVals.append(e)
grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)
##
#####
# whiten using eigenvalues
weightDecayCoeff = 0.
if var in self._weight_decay_dict:
weightDecayCoeff = self._weight_decay_dict[var]
if KFAC_DEBUG:
print(('weight decay coeff for %s is %f' % (var.name, weightDecayCoeff)))
if self._factored_damping:
if KFAC_DEBUG:
print(('use factored damping for %s' % (var.name)))
coeffs = 1.
num_factors = len(eigVals)
# compute the ratio of the trace norms of the left and right
# K-FAC factor matrices, and its generalization to more factors
if len(eigVals) == 1:
damping = self._epsilon + weightDecayCoeff
else:
damping = tf.pow(
self._epsilon + weightDecayCoeff, 1. / num_factors)
eigVals_tnorm_avg = [tf.reduce_mean(
tf.abs(e)) for e in eigVals]
for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg):
eig_tnorm_negList = [
item for item in eigVals_tnorm_avg if item != e_tnorm]
if len(eigVals) == 1:
adjustment = 1.
elif len(eigVals) == 2:
adjustment = tf.sqrt(
e_tnorm / eig_tnorm_negList[0])
else:
eig_tnorm_negList_prod = reduce(
lambda x, y: x * y, eig_tnorm_negList)
adjustment = tf.pow(
tf.pow(e_tnorm, num_factors - 1.) / eig_tnorm_negList_prod, 1. / num_factors)
coeffs *= (e + adjustment * damping)
else:
coeffs = 1.
damping = (self._epsilon + weightDecayCoeff)
for e in eigVals:
coeffs *= e
coeffs += damping
#grad = tf.Print(grad, [tf.convert_to_tensor('1'), tf.convert_to_tensor(var.name), grad.get_shape()])
grad /= coeffs
#grad = tf.Print(grad, [tf.convert_to_tensor('2'), tf.convert_to_tensor(var.name), grad.get_shape()])
#####
# project gradient back to euclidean space
for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
Q = self.stats_eigen[stats]['Q']
grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)
for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
Q = self.stats_eigen[stats]['Q']
grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)
##
#grad = tf.Print(grad, [tf.convert_to_tensor('3'), tf.convert_to_tensor(var.name), grad.get_shape()])
if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
# the homogeneous-coordinate trick only works for 2D grads.
# TO-DO: figure out how to factorize the bias grad
# un-stack bias grad
var_assnBias = self.stats[var]['assnBias']
C_plus_one = int(grad.get_shape()[0])
grad_assnBias = tf.reshape(tf.slice(grad,
begin=[
C_plus_one - 1, 0],
size=[1, -1]), var_assnBias.get_shape())
grad_assnWeights = tf.slice(grad,
begin=[0, 0],
size=[C_plus_one - 1, -1])
grad_dict[var_assnBias] = grad_assnBias
grad = grad_assnWeights
#grad = tf.Print(grad, [tf.convert_to_tensor('4'), tf.convert_to_tensor(var.name), grad.get_shape()])
if GRAD_RESHAPE:
grad = tf.reshape(grad, GRAD_SHAPE)
grad_dict[var] = grad
print(('projecting %d gradient matrices' % counter))
for g, var in zip(gradlist, varlist):
grad = grad_dict[var]
### clipping ###
if KFAC_DEBUG:
print(('apply clipping to %s' % (var.name)))
tf.Print(grad, [tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)))], "Euclidean norm of new grad")
local_vg = tf.reduce_sum(grad * g * (self._lr * self._lr))
vg += local_vg
# rescale everything
if KFAC_DEBUG:
print('apply vFv clipping')
scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
if KFAC_DEBUG:
scaling = tf.Print(scaling, [tf.convert_to_tensor(
'clip: '), scaling, tf.convert_to_tensor(' vFv: '), vg])
with tf.control_dependencies([tf.assign(self.vFv, vg)]):
updatelist = [grad_dict[var] for var in varlist]
for i, item in enumerate(updatelist):
updatelist[i] = scaling * item
return updatelist
def compute_gradients(self, loss, var_list=None):
varlist = var_list
if varlist is None:
varlist = tf.trainable_variables()
g = tf.gradients(loss, varlist)
return [(a, b) for a, b in zip(g, varlist)]
def apply_gradients_kfac(self, grads):
g, varlist = list(zip(*grads))
if len(self.stats_eigen) == 0:
self.getStatsEigen()
qr = None
# launch eigen-decomp on a queue thread
if self._async:
print('Use async eigen decomp')
# get a list of factor loading tensors
factorOps_dummy = self.computeStatsEigen()
# define a queue for the list of factor loading tensors
queue = tf.FIFOQueue(1, [item.dtype for item in factorOps_dummy], shapes=[
item.get_shape() for item in factorOps_dummy])
enqueue_op = tf.cond(tf.logical_and(tf.equal(tf.mod(self.stats_step, self._kfac_update), tf.convert_to_tensor(
0)), tf.greater_equal(self.stats_step, self._stats_accum_iter)), lambda: queue.enqueue(self.computeStatsEigen()), tf.no_op)
def dequeue_op():
return queue.dequeue()
qr = tf.train.QueueRunner(queue, [enqueue_op])
updateOps = []
global_step_op = tf.assign_add(self.global_step, 1)
updateOps.append(global_step_op)
with tf.control_dependencies([global_step_op]):
# compute updates
assert self._update_stats_op is not None
updateOps.append(self._update_stats_op)
dependency_list = []
if not self._async:
dependency_list.append(self._update_stats_op)
with tf.control_dependencies(dependency_list):
def no_op_wrapper():
return tf.group(*[tf.assign_add(self.cold_step, 1)])
if not self._async:
# synchronous eigen-decomp updates
updateFactorOps = tf.cond(tf.logical_and(tf.equal(tf.mod(self.stats_step, self._kfac_update),
tf.convert_to_tensor(0)),
tf.greater_equal(self.stats_step, self._stats_accum_iter)), lambda: tf.group(*self.applyStatsEigen(self.computeStatsEigen())), no_op_wrapper)
else:
# asynchronous eigen-decomp updates using queue
updateFactorOps = tf.cond(tf.greater_equal(self.stats_step, self._stats_accum_iter),
lambda: tf.cond(tf.equal(queue.size(), tf.convert_to_tensor(0)),
tf.no_op,
lambda: tf.group(
*self.applyStatsEigen(dequeue_op())),
),
no_op_wrapper)
updateOps.append(updateFactorOps)
with tf.control_dependencies([updateFactorOps]):
def gradOp():
return list(g)
def getKfacGradOp():
return self.getKfacPrecondUpdates(g, varlist)
u = tf.cond(tf.greater(self.factor_step,
tf.convert_to_tensor(0)), getKfacGradOp, gradOp)
optim = tf.train.MomentumOptimizer(
self._lr * (1. - self._momentum), self._momentum)
#optim = tf.train.AdamOptimizer(self._lr, epsilon=0.01)
def optimOp():
def updateOptimOp():
if self._full_stats_init:
return tf.cond(tf.greater(self.factor_step, tf.convert_to_tensor(0)), lambda: optim.apply_gradients(list(zip(u, varlist))), tf.no_op)
else:
return optim.apply_gradients(list(zip(u, varlist)))
if self._full_stats_init:
return tf.cond(tf.greater_equal(self.stats_step, self._stats_accum_iter), updateOptimOp, tf.no_op)
else:
return tf.cond(tf.greater_equal(self.sgd_step, self._cold_iter), updateOptimOp, tf.no_op)
updateOps.append(optimOp())
return tf.group(*updateOps), qr
def apply_gradients(self, grads):
coldOptim = tf.train.MomentumOptimizer(
self._cold_lr, self._momentum)
def coldSGDstart():
sgd_grads, sgd_var = zip(*grads)
if self.max_grad_norm is not None:
sgd_grads, sgd_grad_norm = tf.clip_by_global_norm(sgd_grads, self.max_grad_norm)
sgd_grads = list(zip(sgd_grads, sgd_var))
sgd_step_op = tf.assign_add(self.sgd_step, 1)
coldOptim_op = coldOptim.apply_gradients(sgd_grads)
if KFAC_DEBUG:
with tf.control_dependencies([sgd_step_op, coldOptim_op]):
sgd_step_op = tf.Print(
sgd_step_op, [self.sgd_step, tf.convert_to_tensor('doing cold sgd step')])
return tf.group(*[sgd_step_op, coldOptim_op])
kfacOptim_op, qr = self.apply_gradients_kfac(grads)
def warmKFACstart():
return kfacOptim_op
return tf.cond(tf.greater(self.sgd_step, self._cold_iter), warmKFACstart, coldSGDstart), qr
def minimize(self, loss, loss_sampled, var_list=None):
grads = self.compute_gradients(loss, var_list=var_list)
update_stats_op = self.compute_and_apply_stats(
loss_sampled, var_list=var_list)
return self.apply_gradients(grads)
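# --- Hedged usage sketch (added for illustration; not part of the original file).
# Assuming this optimizer is exported as KfacOptimizer (as in baselines.acktr.kfac)
# and that `loss` / `loss_sampled` are scalar graph-mode tensors built elsewhere,
# the typical TF1 driving loop looks roughly like:
#
#     optim = KfacOptimizer(learning_rate=0.03, cold_iter=10, kfac_update=2)
#     update_op, q_runner = optim.minimize(loss, loss_sampled)
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         coord = tf.train.Coordinator()
#         if q_runner is not None:  # only populated when is_async=True
#             q_runner.create_threads(sess, coord=coord, start=True)
#         for _ in range(num_updates):
#             sess.run(update_op)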

View File

@@ -1,86 +0,0 @@
import tensorflow as tf
def gmatmul(a, b, transpose_a=False, transpose_b=False, reduce_dim=None):
assert reduce_dim is not None
# weird batch matmul
if len(a.get_shape()) == 2 and len(b.get_shape()) > 2:
# reshape reduce_dim to the left most dim in b
b_shape = b.get_shape()
if reduce_dim != 0:
b_dims = list(range(len(b_shape)))
b_dims.remove(reduce_dim)
b_dims.insert(0, reduce_dim)
b = tf.transpose(b, b_dims)
b_t_shape = b.get_shape()
b = tf.reshape(b, [int(b_shape[reduce_dim]), -1])
result = tf.matmul(a, b, transpose_a=transpose_a,
transpose_b=transpose_b)
result = tf.reshape(result, b_t_shape)
if reduce_dim != 0:
b_dims = list(range(len(b_shape)))
b_dims.remove(0)
b_dims.insert(reduce_dim, 0)
result = tf.transpose(result, b_dims)
return result
elif len(a.get_shape()) > 2 and len(b.get_shape()) == 2:
# reshape reduce_dim to the right most dim in a
a_shape = a.get_shape()
outter_dim = len(a_shape) - 1
reduce_dim = len(a_shape) - reduce_dim - 1
if reduce_dim != outter_dim:
a_dims = list(range(len(a_shape)))
a_dims.remove(reduce_dim)
a_dims.insert(outter_dim, reduce_dim)
a = tf.transpose(a, a_dims)
a_t_shape = a.get_shape()
a = tf.reshape(a, [-1, int(a_shape[reduce_dim])])
result = tf.matmul(a, b, transpose_a=transpose_a,
transpose_b=transpose_b)
result = tf.reshape(result, a_t_shape)
if reduce_dim != outter_dim:
a_dims = list(range(len(a_shape)))
a_dims.remove(outter_dim)
a_dims.insert(reduce_dim, outter_dim)
result = tf.transpose(result, a_dims)
return result
elif len(a.get_shape()) == 2 and len(b.get_shape()) == 2:
return tf.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b)
assert False, 'something went wrong'
def clipoutNeg(vec, threshold=1e-6):
mask = tf.cast(vec > threshold, tf.float32)
return mask * vec
def detectMinVal(input_mat, var, threshold=1e-6, name='', debug=False):
eigen_min = tf.reduce_min(input_mat)
eigen_max = tf.reduce_max(input_mat)
eigen_ratio = eigen_max / eigen_min
input_mat_clipped = clipoutNeg(input_mat, threshold)
if debug:
input_mat_clipped = tf.cond(tf.logical_or(tf.greater(eigen_ratio, 0.), tf.less(eigen_ratio, -500)), lambda: input_mat_clipped, lambda: tf.Print(
input_mat_clipped, [tf.convert_to_tensor('screwed ratio ' + name + ' eigen values!!!'), tf.convert_to_tensor(var.name), eigen_min, eigen_max, eigen_ratio]))
return input_mat_clipped
def factorReshape(Q, e, grad, facIndx=0, ftype='act'):
grad_shape = grad.get_shape()
if ftype == 'act':
assert e.get_shape()[0] == grad_shape[facIndx]
expanded_shape = [1, ] * len(grad_shape)
expanded_shape[facIndx] = -1
e = tf.reshape(e, expanded_shape)
if ftype == 'grad':
assert e.get_shape()[0] == grad_shape[len(grad_shape) - facIndx - 1]
expanded_shape = [1, ] * len(grad_shape)
expanded_shape[len(grad_shape) - facIndx - 1] = -1
e = tf.reshape(e, expanded_shape)
return Q, e
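# Hedged example (added for illustration; not part of the original file): gmatmul
# contracts a square factor matrix with one axis of a higher-rank tensor, which is
# how kfac.py applies the eigen-basis rotations to reshaped conv-kernel gradients.
def _gmatmul_example():
    import numpy as np
    Q = tf.constant(np.eye(4), dtype=tf.float32)                # 4 x 4 factor matrix
    grad = tf.constant(np.ones((4, 3, 5)), dtype=tf.float32)    # rank-3 "gradient"
    # rotate axis 0 of `grad` into the eigen basis defined by Q
    return gmatmul(Q, grad, transpose_a=True, reduce_dim=0)     # shape stays (4, 3, 5)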

View File

@@ -1,28 +0,0 @@
import tensorflow as tf
def dense(x, size, name, weight_init=None, bias_init=0, weight_loss_dict=None, reuse=None):
with tf.variable_scope(name, reuse=reuse):
assert (len(tf.get_variable_scope().name.split('/')) == 2)
w = tf.get_variable("w", [x.get_shape()[1], size], initializer=weight_init)
b = tf.get_variable("b", [size], initializer=tf.constant_initializer(bias_init))
weight_decay_fc = 3e-4
if weight_loss_dict is not None:
weight_decay = tf.multiply(tf.nn.l2_loss(w), weight_decay_fc, name='weight_decay_loss')
if weight_loss_dict is not None:
weight_loss_dict[w] = weight_decay_fc
weight_loss_dict[b] = 0.0
tf.add_to_collection(tf.get_variable_scope().name.split('/')[0] + '_' + 'losses', weight_decay)
return tf.nn.bias_add(tf.matmul(x, w), b)
def kl_div(action_dist1, action_dist2, action_size):
mean1, std1 = action_dist1[:, :action_size], action_dist1[:, action_size:]
mean2, std2 = action_dist2[:, :action_size], action_dist2[:, action_size:]
numerator = tf.square(mean1 - mean2) + tf.square(std1) - tf.square(std2)
denominator = 2 * tf.square(std2) + 1e-8
return tf.reduce_sum(
numerator/denominator + tf.log(std2) - tf.log(std1),reduction_indices=-1)
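# Hedged example (added for illustration; not part of the original file): each
# distribution is a row of [means, stds] concatenated along the last axis, so with
# unit stds the KL between N(0,1) x N(0,1) and N(1,1) x N(0,1) comes out to 0.5.
def _kl_div_example():
    action_size = 2
    dist1 = tf.constant([[0.0, 0.0, 1.0, 1.0]])   # means (0, 0), stds (1, 1)
    dist2 = tf.constant([[1.0, 0.0, 1.0, 1.0]])   # means (1, 0), stds (1, 1)
    return kl_div(dist1, dist2, action_size)      # evaluates to ~0.5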

View File

@@ -1,3 +1,2 @@
# flake8: noqa F403
from baselines.bench.benchmarks import *
from baselines.bench.monitor import *

View File

@@ -1,31 +0,0 @@
from .monitor import Monitor
import gym
import json
def test_monitor():
import pandas
import os
import uuid
env = gym.make("CartPole-v1")
env.seed(0)
mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4()
menv = Monitor(env, mon_file)
menv.reset()
for _ in range(1000):
_, _, done, _ = menv.step(0)
if done:
menv.reset()
f = open(mon_file, 'rt')
firstline = f.readline()
assert firstline.startswith('#')
metadata = json.loads(firstline[1:])
assert metadata['env_id'] == "CartPole-v1"
assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata"
last_logline = pandas.read_csv(f, index_col=None)
assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline"
f.close()
os.remove(mon_file)

View File

@@ -17,26 +17,21 @@ from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common import retro_wrappers
from baselines.common.wrappers import ClipActionsWrapper
def make_vec_env(env_id, env_type, num_env, seed,
wrapper_kwargs=None,
env_kwargs=None,
start_index=0,
reward_scale=1.0,
flatten_dict_observations=True,
gamestate=None,
initializer=None,
force_dummy=False):
gamestate=None):
"""
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
"""
wrapper_kwargs = wrapper_kwargs or {}
env_kwargs = env_kwargs or {}
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
seed = seed + 10000 * mpi_rank if seed is not None else None
logger_dir = logger.get_dir()
def make_thunk(rank, initializer=None):
def make_thunk(rank):
return lambda: make_env(
env_id=env_id,
env_type=env_type,
@@ -47,30 +42,18 @@ def make_vec_env(env_id, env_type, num_env, seed,
gamestate=gamestate,
flatten_dict_observations=flatten_dict_observations,
wrapper_kwargs=wrapper_kwargs,
env_kwargs=env_kwargs,
logger_dir=logger_dir,
initializer=initializer
logger_dir=logger_dir
)
set_global_seeds(seed)
if not force_dummy and num_env > 1:
return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)])
if num_env > 1:
return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
else:
return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)])
return DummyVecEnv([make_thunk(start_index)])
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
if initializer is not None:
initializer(mpi_rank=mpi_rank, subrank=subrank)
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
wrapper_kwargs = wrapper_kwargs or {}
env_kwargs = env_kwargs or {}
if ':' in env_id:
import re
import importlib
module_name = re.sub(':.*','',env_id)
env_id = re.sub('.*:', '', env_id)
importlib.import_module(module_name)
if env_type == 'atari':
env = make_atari(env_id)
elif env_type == 'retro':
@@ -78,7 +61,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
gamestate = gamestate or retro.State.DEFAULT
env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
else:
env = gym.make(env_id, **env_kwargs)
env = gym.make(env_id)
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
keys = env.observation_space.spaces.keys()
@@ -89,7 +72,6 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
allow_early_resets=True)
if env_type == 'atari':
env = wrap_deepmind(env, **wrapper_kwargs)
elif env_type == 'retro':
@@ -97,9 +79,6 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
wrapper_kwargs['frame_stack'] = 1
env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)
if isinstance(env.action_space, gym.spaces.Box):
env = ClipActionsWrapper(env)
if reward_scale != 1:
env = retro_wrappers.RewardScaler(env, reward_scale)
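# Hedged usage sketch (added for illustration; not part of the original file):
# with the trimmed signature above, a vectorized Atari env is built like this.
def _make_vec_env_example():
    venv = make_vec_env('PongNoFrameskip-v4', 'atari', num_env=4, seed=0)
    obs = venv.reset()   # batched observations for the 4 parallel workers
    return venv, obs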

View File

@@ -1,8 +1,6 @@
import tensorflow as tf
import numpy as np
import baselines.common.tf_util as U
from baselines.a2c.utils import fc
from tensorflow.python.ops import math_ops
class Pd(object):
"""
@@ -31,7 +29,7 @@ class Pd(object):
def __getitem__(self, idx):
return self.__class__(self.flatparam()[idx])
class PdType(object):
class PdType(tf.Module):
"""
Parametrized family of probability distributions
"""
@@ -39,7 +37,7 @@ class PdType(object):
raise NotImplementedError
def pdfromflat(self, flat):
return self.pdclass()(flat)
def pdfromlatent(self, latent_vector, init_scale, init_bias):
def pdfromlatent(self, latent_vector):
raise NotImplementedError
def param_shape(self):
raise NotImplementedError
@@ -48,21 +46,18 @@ class PdType(object):
def sample_dtype(self):
raise NotImplementedError
def param_placeholder(self, prepend_shape, name=None):
return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)
def sample_placeholder(self, prepend_shape, name=None):
return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)
def __eq__(self, other):
return (type(self) == type(other)) and (self.__dict__ == other.__dict__)
class CategoricalPdType(PdType):
def __init__(self, ncat):
def __init__(self, latent_shape, ncat, init_scale=1.0, init_bias=0.0):
self.ncat = ncat
self.matching_fc = _matching_fc(latent_shape, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
def pdclass(self):
return CategoricalPd
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
pdparam = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
def pdfromlatent(self, latent_vector):
pdparam = self.matching_fc(latent_vector)
return self.pdfromflat(pdparam), pdparam
def param_shape(self):
@@ -72,37 +67,18 @@ class CategoricalPdType(PdType):
def sample_dtype(self):
return tf.int32
class MultiCategoricalPdType(PdType):
def __init__(self, nvec):
self.ncats = nvec.astype('int32')
assert (self.ncats > 0).all()
def pdclass(self):
return MultiCategoricalPd
def pdfromflat(self, flat):
return MultiCategoricalPd(self.ncats, flat)
def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
return self.pdfromflat(pdparam), pdparam
def param_shape(self):
return [sum(self.ncats)]
def sample_shape(self):
return [len(self.ncats)]
def sample_dtype(self):
return tf.int32
class DiagGaussianPdType(PdType):
def __init__(self, size):
def __init__(self, latent_shape, size, init_scale=1.0, init_bias=0.0):
self.size = size
self.matching_fc = _matching_fc(latent_shape, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
self.logstd = tf.Variable(np.zeros((1, self.size)), name='pi/logstd', dtype=tf.float32)
def pdclass(self):
return DiagGaussianPd
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
def pdfromlatent(self, latent_vector):
mean = self.matching_fc(latent_vector)
pdparam = tf.concat([mean, mean * 0.0 + self.logstd], axis=1)
return self.pdfromflat(pdparam), mean
def param_shape(self):
@@ -112,43 +88,6 @@ class DiagGaussianPdType(PdType):
def sample_dtype(self):
return tf.float32
class BernoulliPdType(PdType):
def __init__(self, size):
self.size = size
def pdclass(self):
return BernoulliPd
def param_shape(self):
return [self.size]
def sample_shape(self):
return [self.size]
def sample_dtype(self):
return tf.int32
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
pdparam = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
return self.pdfromflat(pdparam), pdparam
# WRONG SECOND DERIVATIVES
# class CategoricalPd(Pd):
# def __init__(self, logits):
# self.logits = logits
# self.ps = tf.nn.softmax(logits)
# @classmethod
# def fromflat(cls, flat):
# return cls(flat)
# def flatparam(self):
# return self.logits
# def mode(self):
# return U.argmax(self.logits, axis=-1)
# def logp(self, x):
# return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
# def kl(self, other):
# return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
# - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
# def entropy(self):
# return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
# def sample(self):
# u = tf.random_uniform(tf.shape(self.logits))
# return U.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
class CategoricalPd(Pd):
def __init__(self, logits):
@@ -161,6 +100,7 @@ class CategoricalPd(Pd):
@property
def mean(self):
return tf.nn.softmax(self.logits)
def neglogp(self, x):
# return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
# Note: we can't use sparse_softmax_cross_entropy_with_logits because
@@ -176,11 +116,11 @@ class CategoricalPd(Pd):
x = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
else:
# already encoded
print('logits is {}'.format(self.logits))
assert x.shape.as_list() == self.logits.shape.as_list()
return tf.nn.softmax_cross_entropy_with_logits_v2(
logits=self.logits,
labels=x)
return tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
def kl(self, other):
a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True)
@@ -189,41 +129,20 @@ class CategoricalPd(Pd):
z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
z1 = tf.reduce_sum(ea1, axis=-1, keepdims=True)
p0 = ea0 / z0
return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)
return tf.reduce_sum(p0 * (a0 - tf.math.log(z0) - a1 + tf.math.log(z1)), axis=-1)
def entropy(self):
a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
ea0 = tf.exp(a0)
z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
p0 = ea0 / z0
return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)
return tf.reduce_sum(p0 * (tf.math.log(z0) - a0), axis=-1)
def sample(self):
u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
u = tf.random.uniform(tf.shape(self.logits), dtype=self.logits.dtype, seed=0)
return tf.argmax(self.logits - tf.math.log(-tf.math.log(u)), axis=-1)
@classmethod
def fromflat(cls, flat):
return cls(flat)
class MultiCategoricalPd(Pd):
def __init__(self, nvec, flat):
self.flat = flat
self.categoricals = list(map(CategoricalPd,
tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)))
def flatparam(self):
return self.flat
def mode(self):
return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
def neglogp(self, x):
return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))])
def kl(self, other):
return tf.add_n([p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
def entropy(self):
return tf.add_n([p.entropy() for p in self.categoricals])
def sample(self):
return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
@classmethod
def fromflat(cls, flat):
raise NotImplementedError
class DiagGaussianPd(Pd):
def __init__(self, flat):
self.flat = flat
@@ -237,7 +156,7 @@ class DiagGaussianPd(Pd):
return self.mean
def neglogp(self, x):
return 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), axis=-1) \
+ 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \
+ 0.5 * np.log(2.0 * np.pi) * tf.cast(tf.shape(x)[-1], dtype=tf.float32) \
+ tf.reduce_sum(self.logstd, axis=-1)
def kl(self, other):
assert isinstance(other, DiagGaussianPd)
@@ -245,111 +164,23 @@ class DiagGaussianPd(Pd):
def entropy(self):
return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
def sample(self):
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
return self.mean + self.std * tf.random.normal(tf.shape(self.mean))
@classmethod
def fromflat(cls, flat):
return cls(flat)
class BernoulliPd(Pd):
def __init__(self, logits):
self.logits = logits
self.ps = tf.sigmoid(logits)
def flatparam(self):
return self.logits
@property
def mean(self):
return self.ps
def mode(self):
return tf.round(self.ps)
def neglogp(self, x):
return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1)
def kl(self, other):
return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
def entropy(self):
return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
def sample(self):
u = tf.random_uniform(tf.shape(self.ps))
return tf.to_float(math_ops.less(u, self.ps))
@classmethod
def fromflat(cls, flat):
return cls(flat)
def make_pdtype(ac_space):
def make_pdtype(latent_shape, ac_space, init_scale=1.0):
from gym import spaces
if isinstance(ac_space, spaces.Box):
assert len(ac_space.shape) == 1
return DiagGaussianPdType(ac_space.shape[0])
return DiagGaussianPdType(latent_shape, ac_space.shape[0], init_scale)
elif isinstance(ac_space, spaces.Discrete):
return CategoricalPdType(ac_space.n)
elif isinstance(ac_space, spaces.MultiDiscrete):
return MultiCategoricalPdType(ac_space.nvec)
elif isinstance(ac_space, spaces.MultiBinary):
return BernoulliPdType(ac_space.n)
return CategoricalPdType(latent_shape, ac_space.n, init_scale)
else:
raise NotImplementedError
raise ValueError('No implementation for {}'.format(ac_space))
def shape_el(v, i):
maybe = v.get_shape()[i]
if maybe is not None:
return maybe
def _matching_fc(tensor_shape, name, size, init_scale, init_bias):
if tensor_shape[-1] == size:
return lambda x: x
else:
return tf.shape(v)[i]
@U.in_session
def test_probtypes():
np.random.seed(0)
pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101
validate_probtype(diag_gauss, pdparam_diag_gauss)
pdparam_categorical = np.array([-.2, .3, .5])
categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101
validate_probtype(categorical, pdparam_categorical)
nvec = [1,2,3]
pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1])
multicategorical = MultiCategoricalPdType(nvec) #pylint: disable=E1101
validate_probtype(multicategorical, pdparam_multicategorical)
pdparam_bernoulli = np.array([-.2, .3, .5])
bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101
validate_probtype(bernoulli, pdparam_bernoulli)
def validate_probtype(probtype, pdparam):
N = 100000
# Check to see if mean negative log likelihood == differential entropy
Mval = np.repeat(pdparam[None, :], N, axis=0)
M = probtype.param_placeholder([N])
X = probtype.sample_placeholder([N])
pd = probtype.pdfromflat(M)
calcloglik = U.function([X, M], pd.logp(X))
calcent = U.function([M], pd.entropy())
Xval = tf.get_default_session().run(pd.sample(), feed_dict={M:Mval})
logliks = calcloglik(Xval, Mval)
entval_ll = - logliks.mean() #pylint: disable=E1101
entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
entval = calcent(Mval).mean() #pylint: disable=E1101
assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas
# Check to see if kldiv[p,q] = - ent[p] - E_p[log q]
M2 = probtype.param_placeholder([N])
pd2 = probtype.pdfromflat(M2)
q = pdparam + np.random.randn(pdparam.size) * 0.1
Mval2 = np.repeat(q[None, :], N, axis=0)
calckl = U.function([M, M2], pd.kl(pd2))
klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101
logliks = calcloglik(Xval, Mval2)
klval_ll = - entval - logliks.mean() #pylint: disable=E1101
klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
print('ok on', probtype, pdparam)
def _matching_fc(tensor, name, size, init_scale, init_bias):
if tensor.shape[-1] == size:
return tensor
else:
return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)
return fc(tensor_shape, name, size, init_scale=init_scale, init_bias=init_bias)
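# Hedged example (added for illustration; not part of the original file): the pdtype
# now builds its projection layer once from the latent shape at construction time,
# so per-step use is just a call on the latent features.
def _pdtype_example():
    from gym import spaces
    latent = tf.zeros((1, 64))                               # hypothetical policy latent
    pdtype = make_pdtype(latent.shape, spaces.Discrete(6))
    pd, logits = pdtype.pdfromlatent(latent)                 # logits shape (1, 6)
    return pd.sample(), logits                               # int actions of shape (1,)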

View File

@@ -1,64 +0,0 @@
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiDiscrete
def observation_placeholder(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space
Parameters:
----------
ob_space: gym.Space observation space
batch_size: int size of the batch to be fed into input. Can be left None in most cases.
name: str name of the placeholder
Returns:
-------
tensorflow placeholder tensor
'''
assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
'Can only deal with Discrete and Box observation spaces for now'
dtype = ob_space.dtype
if dtype == np.int8:
dtype = np.uint8
return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
def observation_input(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space, and add input
encoder of the appropriate type.
'''
placeholder = observation_placeholder(ob_space, batch_size, name)
return placeholder, encode_observation(ob_space, placeholder)
def encode_observation(ob_space, placeholder):
'''
Encode input in the way that is appropriate to the observation space
Parameters:
----------
ob_space: gym.Space observation space
placeholder: tf.placeholder observation input placeholder
'''
if isinstance(ob_space, Discrete):
return tf.to_float(tf.one_hot(placeholder, ob_space.n))
elif isinstance(ob_space, Box):
return tf.to_float(placeholder)
elif isinstance(ob_space, MultiDiscrete):
placeholder = tf.cast(placeholder, tf.int32)
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
else:
raise NotImplementedError

View File

@@ -55,7 +55,7 @@ def set_global_seeds(i):
myseed = i + 1000 * rank if i is not None else None
try:
import tensorflow as tf
tf.set_random_seed(myseed)
tf.random.set_seed(myseed)
except ImportError:
pass
np.random.seed(myseed)

View File

@@ -1,8 +1,6 @@
import numpy as np
import tensorflow as tf
from baselines.a2c import utils
from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.a2c.utils import ortho_init, conv
mapping = {}
@@ -12,67 +10,26 @@ def register(name):
return func
return _thunk
def nature_cnn(unscaled_images, **conv_kwargs):
def nature_cnn(input_shape, **conv_kwargs):
"""
CNN from Nature paper.
"""
scaled_images = tf.cast(unscaled_images, tf.float32) / 255.
activ = tf.nn.relu
h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2),
**conv_kwargs))
h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs))
h3 = conv_to_fc(h3)
return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))
def build_impala_cnn(unscaled_images, depths=[16,32,32], **conv_kwargs):
"""
Model used in the paper "IMPALA: Scalable Distributed Deep-RL with
Importance Weighted Actor-Learner Architectures" https://arxiv.org/abs/1802.01561
"""
layer_num = 0
def get_layer_num_str():
nonlocal layer_num
num_str = str(layer_num)
layer_num += 1
return num_str
def conv_layer(out, depth):
return tf.layers.conv2d(out, depth, 3, padding='same', name='layer_' + get_layer_num_str())
def residual_block(inputs):
depth = inputs.get_shape()[-1].value
out = tf.nn.relu(inputs)
out = conv_layer(out, depth)
out = tf.nn.relu(out)
out = conv_layer(out, depth)
return out + inputs
def conv_sequence(inputs, depth):
out = conv_layer(inputs, depth)
out = tf.layers.max_pooling2d(out, pool_size=3, strides=2, padding='same')
out = residual_block(out)
out = residual_block(out)
return out
out = tf.cast(unscaled_images, tf.float32) / 255.
for depth in depths:
out = conv_sequence(out, depth)
out = tf.layers.flatten(out)
out = tf.nn.relu(out)
out = tf.layers.dense(out, 256, activation=tf.nn.relu, name='layer_' + get_layer_num_str())
return out
print('input shape is {}'.format(input_shape))
x_input = tf.keras.Input(shape=input_shape, dtype=tf.uint8)
h = x_input
h = tf.cast(h, tf.float32) / 255.
h = conv('c1', nf=32, rf=8, stride=4, activation='relu', init_scale=np.sqrt(2))(h)
h2 = conv('c2', nf=64, rf=4, stride=2, activation='relu', init_scale=np.sqrt(2))(h)
h3 = conv('c3', nf=64, rf=3, stride=1, activation='relu', init_scale=np.sqrt(2))(h2)
h3 = tf.keras.layers.Flatten()(h3)
h3 = tf.keras.layers.Dense(units=512, kernel_initializer=ortho_init(np.sqrt(2)),
name='fc1', activation='relu')(h3)
network = tf.keras.Model(inputs=[x_input], outputs=[h3])
return network
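# Hedged example (added for illustration; not part of the original file): the builder
# now takes the observation shape and returns a Keras model mapping uint8 frames to
# a 512-d feature vector.
def _nature_cnn_example():
    model = nature_cnn(input_shape=(84, 84, 4))
    frames = tf.zeros((1, 84, 84, 4), dtype=tf.uint8)   # dummy batch of stacked frames
    return model(frames)                                # features of shape (1, 512)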
@register("mlp")
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
"""
Stack of fully-connected layers to be used in a policy / q-function approximator
@@ -90,169 +47,54 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
function that builds fully connected network with a given input tensor / placeholder
"""
def network_fn(X):
h = tf.layers.flatten(X)
def network_fn(input_shape):
print('input shape is {}'.format(input_shape))
x_input = tf.keras.Input(shape=input_shape)
# h = tf.keras.layers.Flatten(x_input)
h = x_input
for i in range(num_layers):
h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
if layer_norm:
h = tf.contrib.layers.layer_norm(h, center=True, scale=True)
h = activation(h)
h = tf.keras.layers.Dense(units=num_hidden, kernel_initializer=ortho_init(np.sqrt(2)),
name='mlp_fc{}'.format(i), activation=activation)(h)
return h
network = tf.keras.Model(inputs=[x_input], outputs=[h])
return network
return network_fn
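# Hedged example (added for illustration; not part of the original file): network_fn
# now returns a Keras model built from the observation shape, rather than acting on a
# graph-mode tensor.
def _mlp_example():
    network_fn = mlp(num_layers=2, num_hidden=64)
    model = network_fn(input_shape=(4,))     # e.g. a CartPole-sized observation
    return model(tf.zeros((1, 4)))           # latent features of shape (1, 64)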
@register("cnn")
def cnn(**conv_kwargs):
def network_fn(X):
return nature_cnn(X, **conv_kwargs)
def network_fn(input_shape):
return nature_cnn(input_shape, **conv_kwargs)
return network_fn
@register("impala_cnn")
def impala_cnn(**conv_kwargs):
def network_fn(X):
return build_impala_cnn(X)
return network_fn
@register("cnn_small")
def cnn_small(**conv_kwargs):
def network_fn(X):
h = tf.cast(X, tf.float32) / 255.
activ = tf.nn.relu
h = activ(conv(h, 'c1', nf=8, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs))
h = activ(conv(h, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
h = conv_to_fc(h)
h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2)))
return h
return network_fn
@register("lstm")
def lstm(nlstm=128, layer_norm=False):
"""
Builds an LSTM (Long Short-Term Memory) network to be used in a policy.
Note that the resulting function returns not only the output of the LSTM
(i.e. hidden state of lstm for each step in the sequence), but also a dictionary
with auxiliary tensors to be set as policy attributes.
Specifically,
S is a placeholder to feed current state (LSTM state has to be managed outside policy)
M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
initial_state is a numpy array containing initial lstm state (usually zeros)
state is the output LSTM state (to be fed into S at the next call)
An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example
Parameters:
----------
nlstm: int LSTM hidden state size
layer_norm: bool if True, layer-normalized version of LSTM is used
Returns:
-------
function that builds LSTM with a given input tensor / placeholder
"""
def network_fn(X, nenv=1):
nbatch = X.shape[0]
nsteps = nbatch // nenv
h = tf.layers.flatten(X)
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
xs = batch_to_seq(h, nenv, nsteps)
ms = batch_to_seq(M, nenv, nsteps)
if layer_norm:
h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
else:
h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
h = seq_to_batch(h5)
initial_state = np.zeros(S.shape.as_list(), dtype=float)
return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
return network_fn
@register("cnn_lstm")
def cnn_lstm(nlstm=128, layer_norm=False, conv_fn=nature_cnn, **conv_kwargs):
def network_fn(X, nenv=1):
nbatch = X.shape[0]
nsteps = nbatch // nenv
h = conv_fn(X, **conv_kwargs)
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
xs = batch_to_seq(h, nenv, nsteps)
ms = batch_to_seq(M, nenv, nsteps)
if layer_norm:
h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
else:
h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
h = seq_to_batch(h5)
initial_state = np.zeros(S.shape.as_list(), dtype=float)
return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
return network_fn
@register("impala_cnn_lstm")
def impala_cnn_lstm():
return cnn_lstm(nlstm=256, conv_fn=build_impala_cnn)
@register("cnn_lnlstm")
def cnn_lnlstm(nlstm=128, **conv_kwargs):
return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
'''
convolutions-only net
Parameters:
----------
convs: list of triples (filter_number, filter_size, stride) specifying parameters for each layer.
Returns:
function that takes tensorflow tensor as input and returns the output of the last convolutional layer
'''
def network_fn(X):
out = tf.cast(X, tf.float32) / 255.
with tf.variable_scope("convnet"):
def network_fn(input_shape):
print('input shape is {}'.format(input_shape))
x_input = tf.keras.Input(shape=input_shape, dtype=tf.uint8)
h = x_input
h = tf.cast(h, tf.float32) / 255.
with tf.name_scope("convnet"):
for num_outputs, kernel_size, stride in convs:
out = tf.contrib.layers.convolution2d(out,
num_outputs=num_outputs,
kernel_size=kernel_size,
stride=stride,
activation_fn=tf.nn.relu,
**conv_kwargs)
h = tf.keras.layers.Conv2D(
filters=num_outputs, kernel_size=kernel_size, strides=stride,
activation='relu', **conv_kwargs)(h)
return out
network = tf.keras.Model(inputs=[x_input], outputs=[h])
return network
return network_fn
def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
rms = RunningMeanStd(shape=x.shape[1:])
norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
return norm_x, rms
def get_network_builder(name):
"""

View File

@@ -1,5 +1,4 @@
import baselines.common.tf_util as U
import tensorflow as tf
import numpy as np
try:
from mpi4py import MPI
@@ -59,45 +58,3 @@ class MpiAdam(object):
thetaroot = np.empty_like(thetalocal)
self.comm.Bcast(thetaroot, root=0)
assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
@U.in_session
def test_MpiAdam():
np.random.seed(0)
tf.set_random_seed(0)
a = tf.Variable(np.random.randn(3).astype('float32'))
b = tf.Variable(np.random.randn(2,5).astype('float32'))
loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))
stepsize = 1e-2
update_op = tf.train.AdamOptimizer(stepsize).minimize(loss)
do_update = U.function([], loss, updates=[update_op])
tf.get_default_session().run(tf.global_variables_initializer())
losslist_ref = []
for i in range(10):
l = do_update()
print(i, l)
losslist_ref.append(l)
tf.set_random_seed(0)
tf.get_default_session().run(tf.global_variables_initializer())
var_list = [a,b]
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
adam = MpiAdam(var_list)
losslist_test = []
for i in range(10):
l,g = lossandgrad()
adam.update(g, stepsize)
print(i,l)
losslist_test.append(l)
np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
if __name__ == '__main__':
test_MpiAdam()
View File
@@ -1,54 +1,46 @@
import numpy as np
import tensorflow as tf
from baselines.common import tf_util as U
from baselines.common.tests.test_with_mpi import with_mpi
from baselines import logger
try:
from mpi4py import MPI
except ImportError:
MPI = None
class MpiAdamOptimizer(tf.train.AdamOptimizer):
class MpiAdamOptimizer(tf.Module):
"""Adam optimizer that averages gradients across mpi processes."""
def __init__(self, comm, grad_clip=None, mpi_rank_weight=1, **kwargs):
def __init__(self, comm, var_list):
self.var_list = var_list
self.comm = comm
self.grad_clip = grad_clip
self.mpi_rank_weight = mpi_rank_weight
tf.train.AdamOptimizer.__init__(self, **kwargs)
def compute_gradients(self, loss, var_list, **kwargs):
grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs)
grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0) * self.mpi_rank_weight
shapes = [v.shape.as_list() for g, v in grads_and_vars]
sizes = [int(np.prod(s)) for s in shapes]
self.beta1 = 0.9
self.beta2 = 0.999
self.epsilon = 1e-08
self.t = tf.Variable(0, name='step', dtype=tf.int32)
var_shapes = [v.shape.as_list() for v in var_list]
self.var_sizes = [int(np.prod(s)) for s in var_shapes]
self.flat_var_size = sum(self.var_sizes)
self.m = tf.Variable(np.zeros(self.flat_var_size, 'float32'))
self.v = tf.Variable(np.zeros(self.flat_var_size, 'float32'))
total_weight = np.zeros(1, np.float32)
self.comm.Allreduce(np.array([self.mpi_rank_weight], dtype=np.float32), total_weight, op=MPI.SUM)
total_weight = total_weight[0]
def apply_gradients(self, flat_grad, lr):
buf = np.zeros(self.flat_var_size, np.float32)
self.comm.Allreduce(flat_grad.numpy(), buf, op=MPI.SUM)
avg_flat_grad = np.divide(buf, float(self.comm.Get_size()))
self._apply_gradients(tf.constant(avg_flat_grad), lr)
if self.t.numpy() % 100 == 0:
check_synced(tf.reduce_sum(self.var_list[0]).numpy())
buf = np.zeros(sum(sizes), np.float32)
countholder = [0] # Counts how many times _collect_grads has been called
stat = tf.reduce_sum(grads_and_vars[0][1]) # sum of first variable
def _collect_grads(flat_grad, np_stat):
if self.grad_clip is not None:
gradnorm = np.linalg.norm(flat_grad)
if gradnorm > 1:
flat_grad /= gradnorm
logger.logkv_mean('gradnorm', gradnorm)
logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
np.divide(buf, float(total_weight), out=buf)
if countholder[0] % 100 == 0:
check_synced(np_stat, self.comm)
countholder[0] += 1
return buf
@tf.function
def _apply_gradients(self, avg_flat_grad, lr):
self.t.assign_add(1)
t = tf.cast(self.t, tf.float32)
a = lr * tf.math.sqrt(1 - tf.math.pow(self.beta2, t)) / (1 - tf.math.pow(self.beta1, t))
self.m.assign(self.beta1 * self.m + (1 - self.beta1) * avg_flat_grad)
self.v.assign(self.beta2 * self.v + (1 - self.beta2) * tf.math.square(avg_flat_grad))
flat_step = (- a) * self.m / (tf.math.sqrt(self.v) + self.epsilon)
var_steps = tf.split(flat_step, self.var_sizes, axis=0)
for var_step, var in zip(var_steps, self.var_list):
var.assign_add(tf.reshape(var_step, var.shape))
avg_flat_grad = tf.py_func(_collect_grads, [flat_grad, stat], tf.float32)
avg_flat_grad.set_shape(flat_grad.shape)
avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
for g, (_, v) in zip(avg_grads, grads_and_vars)]
return avg_grads_and_vars
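# Usage sketch (not part of this commit): driving the eager MpiAdamOptimizer above together
# with the reworked flatgrad(grads, var_list) helper from tf_util. The toy loss reuses the
# variables from the removed test_nonfreeze below; import paths follow the existing baselines
# layout and mpi4py is assumed to be installed.
import numpy as np
import tensorflow as tf
from mpi4py import MPI
from baselines.common.tf_util import flatgrad
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer

a = tf.Variable(np.random.randn(3).astype('float32'))
b = tf.Variable(np.random.randn(2, 5).astype('float32'))
var_list = [a, b]
optimizer = MpiAdamOptimizer(MPI.COMM_WORLD, var_list)

for _ in range(10):
    with tf.GradientTape() as tape:
        loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))
    grads = tape.gradient(loss, var_list)
    # gradients are flattened, Allreduce-averaged across ranks, then applied in place
    optimizer.apply_gradients(flatgrad(grads, var_list), lr=1e-2)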
def check_synced(localval, comm=None):
"""
@@ -66,25 +58,3 @@ def check_synced(localval, comm=None):
if comm.rank == 0:
assert all(val==vals[0] for val in vals[1:]),\
'MpiAdamOptimizer detected that different workers have different weights: {}'.format(vals)
@with_mpi(timeout=5)
def test_nonfreeze():
np.random.seed(0)
tf.set_random_seed(0)
a = tf.Variable(np.random.randn(3).astype('float32'))
b = tf.Variable(np.random.randn(2,5).astype('float32'))
loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))
stepsize = 1e-2
# for some reason the session config with inter_op_parallelism_threads was causing
# nested sess.run calls to freeze
config = tf.ConfigProto(inter_op_parallelism_threads=1)
sess = U.get_session(config=config)
update_op = MpiAdamOptimizer(comm=MPI.COMM_WORLD, learning_rate=stepsize).minimize(loss)
sess.run(tf.global_variables_initializer())
losslist_ref = []
for i in range(100):
l,_ = sess.run([loss, update_op])
print(i, l)
losslist_ref.append(l)
View File
@@ -3,110 +3,54 @@ try:
except ImportError:
MPI = None
import tensorflow as tf, baselines.common.tf_util as U, numpy as np
import tensorflow as tf, numpy as np
class RunningMeanStd(object):
class RunningMeanStd(tf.Module):
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
def __init__(self, epsilon=1e-2, shape=()):
def __init__(self, epsilon=1e-2, shape=(), default_clip_range=np.inf):
self._sum = tf.get_variable(
self._sum = tf.Variable(
initial_value=np.zeros(shape=shape, dtype=np.float64),
dtype=tf.float64,
shape=shape,
initializer=tf.constant_initializer(0.0),
name="runningsum", trainable=False)
self._sumsq = tf.get_variable(
self._sumsq = tf.Variable(
initial_value=np.full(shape=shape, fill_value=epsilon, dtype=np.float64),
dtype=tf.float64,
shape=shape,
initializer=tf.constant_initializer(epsilon),
name="runningsumsq", trainable=False)
self._count = tf.get_variable(
self._count = tf.Variable(
initial_value=epsilon,
dtype=tf.float64,
shape=(),
initializer=tf.constant_initializer(epsilon),
name="count", trainable=False)
self.shape = shape
self.mean = tf.to_float(self._sum / self._count)
self.std = tf.sqrt( tf.maximum( tf.to_float(self._sumsq / self._count) - tf.square(self.mean) , 1e-2 ))
newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var')
newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
self.incfiltparams = U.function([newsum, newsumsq, newcount], [],
updates=[tf.assign_add(self._sum, newsum),
tf.assign_add(self._sumsq, newsumsq),
tf.assign_add(self._count, newcount)])
self.epsilon = epsilon
self.default_clip_range = default_clip_range
def update(self, x):
x = x.astype('float64')
n = int(np.prod(self.shape))
totalvec = np.zeros(n*2+1, 'float64')
addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
totalvec = np.zeros(n*2+1, 'float64')
if MPI is not None:
# totalvec = np.zeros(n*2+1, 'float64')
MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])
# else:
# totalvec = addvec
self._sum.assign_add(totalvec[0:n].reshape(self.shape))
self._sumsq.assign_add(totalvec[n:2*n].reshape(self.shape))
self._count.assign_add(totalvec[2*n])
@U.in_session
def test_runningmeanstd():
for (x1, x2, x3) in [
(np.random.randn(3), np.random.randn(4), np.random.randn(5)),
(np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),
]:
@property
def mean(self):
return tf.cast(self._sum / self._count, tf.float32)
rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
U.initialize()
@property
def std(self):
return tf.sqrt(tf.maximum(tf.cast(self._sumsq / self._count, tf.float32) - tf.square(self.mean), self.epsilon))
x = np.concatenate([x1, x2, x3], axis=0)
ms1 = [x.mean(axis=0), x.std(axis=0)]
rms.update(x1)
rms.update(x2)
rms.update(x3)
ms2 = [rms.mean.eval(), rms.std.eval()]
def normalize(self, v, clip_range=None):
if clip_range is None:
clip_range = self.default_clip_range
return tf.clip_by_value((v - self.mean) / self.std, -clip_range, clip_range)
assert np.allclose(ms1, ms2)
@U.in_session
def test_dist():
np.random.seed(0)
p1,p2,p3=(np.random.randn(3,1), np.random.randn(4,1), np.random.randn(5,1))
q1,q2,q3=(np.random.randn(6,1), np.random.randn(7,1), np.random.randn(8,1))
# p1,p2,p3=(np.random.randn(3), np.random.randn(4), np.random.randn(5))
# q1,q2,q3=(np.random.randn(6), np.random.randn(7), np.random.randn(8))
comm = MPI.COMM_WORLD
assert comm.Get_size()==2
if comm.Get_rank()==0:
x1,x2,x3 = p1,p2,p3
elif comm.Get_rank()==1:
x1,x2,x3 = q1,q2,q3
else:
assert False
rms = RunningMeanStd(epsilon=0.0, shape=(1,))
U.initialize()
rms.update(x1)
rms.update(x2)
rms.update(x3)
bigvec = np.concatenate([p1,p2,p3,q1,q2,q3])
def checkallclose(x,y):
print(x,y)
return np.allclose(x,y)
assert checkallclose(
bigvec.mean(axis=0),
rms.mean.eval(),
)
assert checkallclose(
bigvec.std(axis=0),
rms.std.eval(),
)
if __name__ == "__main__":
# Run with mpirun -np 2 python <filename>
test_dist()
def denormalize(self, v):
return self.mean + v * self.std
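# Usage sketch (not part of this commit): the variable-backed running statistics above with
# the new normalize/denormalize helpers. Shapes are illustrative; mpi4py is assumed to be
# available so that update() actually Allreduces the per-feature sums across workers.
import numpy as np
from baselines.common.mpi_running_mean_std import RunningMeanStd

rms = RunningMeanStd(shape=(3,))
batch = np.random.randn(128, 3).astype('float32')
rms.update(batch)                         # accumulates sum, sum of squares and count
normalized = rms.normalize(batch)         # (x - mean) / std, clipped to default_clip_range
restored = rms.denormalize(normalized)    # mean + normalized * std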
View File
@@ -12,18 +12,16 @@ except ImportError:
MPI = None
def sync_from_root(sess, variables, comm=None):
def sync_from_root(variables, comm=None):
"""
Send the root node's parameters to every worker.
Arguments:
sess: the TensorFlow session.
variables: all parameter variables including optimizer's
"""
if comm is None: comm = MPI.COMM_WORLD
import tensorflow as tf
values = comm.bcast(sess.run(variables))
sess.run([tf.assign(var, val)
for (var, val) in zip(variables, values)])
values = comm.bcast([var.numpy() for var in variables])
for (var, val) in zip(variables, values):
var.assign(val)
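# Usage sketch (not part of this commit): the eager sync_from_root no longer needs a session;
# the small Keras model here is a stand-in for a policy built identically on every rank.
import tensorflow as tf
from mpi4py import MPI
from baselines.common.mpi_util import sync_from_root

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
sync_from_root(model.trainable_variables, comm=MPI.COMM_WORLD)   # every rank now holds the root's weights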
def gpu_count():
"""
View File
@@ -1,101 +1,68 @@
import tensorflow as tf
from baselines.common import tf_util
from baselines.a2c.utils import fc
from baselines.common.distributions import make_pdtype
from baselines.common.input import observation_placeholder, encode_observation
from baselines.common.tf_util import adjust_shape
from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.common.models import get_network_builder
import gym
class PolicyWithValue(object):
class PolicyWithValue(tf.Module):
"""
Encapsulates fields and methods for RL policy and value function estimation with shared parameters
"""
def __init__(self, env, observations, latent, estimate_q=False, vf_latent=None, sess=None, **tensors):
def __init__(self, ac_space, policy_network, value_network=None, estimate_q=False):
"""
Parameters:
----------
env RL environment
ac_space action space
observations tensorflow placeholder in which the observations will be fed
policy_network keras network for policy
latent latent state from which policy distribution parameters should be inferred
value_network keras network for value
vf_latent latent state from which value function should be inferred (if None, then latent is used)
sess tensorflow session to run calculations in (if None, default session is used)
**tensors tensorflow tensors for additional attributes such as state or mask
estimate_q q value or v value
"""
self.X = observations
self.state = tf.constant([])
self.policy_network = policy_network
self.value_network = value_network or policy_network
self.estimate_q = estimate_q
self.initial_state = None
self.__dict__.update(tensors)
vf_latent = vf_latent if vf_latent is not None else latent
vf_latent = tf.layers.flatten(vf_latent)
latent = tf.layers.flatten(latent)
# Based on the action space, will select what probability distribution type
self.pdtype = make_pdtype(env.action_space)
self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01)
# Take an action
self.action = self.pd.sample()
# Calculate the neg log of our probability
self.neglogp = self.pd.neglogp(self.action)
self.sess = sess or tf.get_default_session()
self.pdtype = make_pdtype(policy_network.output_shape, ac_space, init_scale=0.01)
if estimate_q:
assert isinstance(env.action_space, gym.spaces.Discrete)
self.q = fc(vf_latent, 'q', env.action_space.n)
self.vf = self.q
assert isinstance(ac_space, gym.spaces.Discrete)
self.value_fc = fc(self.value_network.output_shape, 'q', ac_space.n)
else:
self.vf = fc(vf_latent, 'vf', 1)
self.vf = self.vf[:,0]
self.value_fc = fc(self.value_network.output_shape, 'vf', 1)
def _evaluate(self, variables, observation, **extra_feed):
sess = self.sess
feed_dict = {self.X: adjust_shape(self.X, observation)}
for inpt_name, data in extra_feed.items():
if inpt_name in self.__dict__.keys():
inpt = self.__dict__[inpt_name]
if isinstance(inpt, tf.Tensor) and inpt._op.type == 'Placeholder':
feed_dict[inpt] = adjust_shape(inpt, data)
return sess.run(variables, feed_dict)
def step(self, observation, **extra_feed):
@tf.function
def step(self, observation):
"""
Compute next action(s) given the observation(s)
Parameters:
----------
observation observation data (either single or a batch)
**extra_feed additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)
observation batched observation data
Returns:
-------
(action, value estimate, next state, negative log likelihood of the action under current policy parameters) tuple
"""
a, v, state, neglogp = self._evaluate([self.action, self.vf, self.state, self.neglogp], observation, **extra_feed)
if state.size == 0:
state = None
return a, v, state, neglogp
latent = self.policy_network(observation)
pd, pi = self.pdtype.pdfromlatent(latent)
action = pd.sample()
neglogp = pd.neglogp(action)
value_latent = self.value_network(observation)
vf = tf.squeeze(self.value_fc(value_latent), axis=1)
return action, vf, None, neglogp
def value(self, ob, *args, **kwargs):
@tf.function
def value(self, observation):
"""
Compute value estimate(s) given the observation(s)
@@ -104,83 +71,11 @@ class PolicyWithValue(object):
observation observation data (either single or a batch)
**extra_feed additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)
Returns:
-------
value estimate
"""
return self._evaluate(self.vf, ob, *args, **kwargs)
def save(self, save_path):
tf_util.save_state(save_path, sess=self.sess)
def load(self, load_path):
tf_util.load_state(load_path, sess=self.sess)
def build_policy(env, policy_network, value_network=None, normalize_observations=False, estimate_q=False, **policy_kwargs):
if isinstance(policy_network, str):
network_type = policy_network
policy_network = get_network_builder(network_type)(**policy_kwargs)
def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None):
ob_space = env.observation_space
X = observ_placeholder if observ_placeholder is not None else observation_placeholder(ob_space, batch_size=nbatch)
extra_tensors = {}
if normalize_observations and X.dtype == tf.float32:
encoded_x, rms = _normalize_clip_observation(X)
extra_tensors['rms'] = rms
else:
encoded_x = X
encoded_x = encode_observation(ob_space, encoded_x)
with tf.variable_scope('pi', reuse=tf.AUTO_REUSE):
policy_latent = policy_network(encoded_x)
if isinstance(policy_latent, tuple):
policy_latent, recurrent_tensors = policy_latent
if recurrent_tensors is not None:
# recurrent architecture, need a few more steps
nenv = nbatch // nsteps
assert nenv > 0, 'Bad input for recurrent policy: batch size {} smaller than nsteps {}'.format(nbatch, nsteps)
policy_latent, recurrent_tensors = policy_network(encoded_x, nenv)
extra_tensors.update(recurrent_tensors)
_v_net = value_network
if _v_net is None or _v_net == 'shared':
vf_latent = policy_latent
else:
if _v_net == 'copy':
_v_net = policy_network
else:
assert callable(_v_net)
with tf.variable_scope('vf', reuse=tf.AUTO_REUSE):
# TODO recurrent architectures are not supported with value_network=copy yet
vf_latent = _v_net(encoded_x)
policy = PolicyWithValue(
env=env,
observations=X,
latent=policy_latent,
vf_latent=vf_latent,
sess=sess,
estimate_q=estimate_q,
**extra_tensors
)
return policy
return policy_fn
def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
rms = RunningMeanStd(shape=x.shape[1:])
norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
return norm_x, rms
value_latent = self.value_network(observation)
result = tf.squeeze(self.value_fc(value_latent), axis=1)
return result
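# Usage sketch (not part of this commit): constructing the eager PolicyWithValue directly,
# assuming the 'mlp' builder follows the same input_shape -> tf.keras.Model convention as
# conv_only earlier in this change set; CartPole is illustrative only.
import numpy as np
import gym
from baselines.common.models import get_network_builder
from baselines.common.policies import PolicyWithValue

env = gym.make('CartPole-v0')
policy_network = get_network_builder('mlp')()(env.observation_space.shape)
policy = PolicyWithValue(env.action_space, policy_network)

obs = np.expand_dims(env.reset(), 0).astype(np.float32)   # step() expects batched observations
action, value, _, neglogp = policy.step(obs)
value_only = policy.value(obs)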
View File
@@ -1,6 +1,4 @@
import tensorflow as tf
import numpy as np
from baselines.common.tf_util import get_session
class RunningMeanStd(object):
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
@@ -31,157 +29,3 @@ def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var,
new_count = tot_count
return new_mean, new_var, new_count
class TfRunningMeanStd(object):
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
'''
TensorFlow variables-based implementation of computing running mean and std
Benefit of this implementation is that it can be saved / loaded together with the tensorflow model
'''
def __init__(self, epsilon=1e-4, shape=(), scope=''):
sess = get_session()
self._new_mean = tf.placeholder(shape=shape, dtype=tf.float64)
self._new_var = tf.placeholder(shape=shape, dtype=tf.float64)
self._new_count = tf.placeholder(shape=(), dtype=tf.float64)
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
self._mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
self._var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
self._count = tf.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64)
self.update_ops = tf.group([
self._var.assign(self._new_var),
self._mean.assign(self._new_mean),
self._count.assign(self._new_count)
])
sess.run(tf.variables_initializer([self._mean, self._var, self._count]))
self.sess = sess
self._set_mean_var_count()
def _set_mean_var_count(self):
self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count])
def update(self, x):
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
new_mean, new_var, new_count = update_mean_var_count_from_moments(self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
self.sess.run(self.update_ops, feed_dict={
self._new_mean: new_mean,
self._new_var: new_var,
self._new_count: new_count
})
self._set_mean_var_count()
def test_runningmeanstd():
for (x1, x2, x3) in [
(np.random.randn(3), np.random.randn(4), np.random.randn(5)),
(np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),
]:
rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
x = np.concatenate([x1, x2, x3], axis=0)
ms1 = [x.mean(axis=0), x.var(axis=0)]
rms.update(x1)
rms.update(x2)
rms.update(x3)
ms2 = [rms.mean, rms.var]
np.testing.assert_allclose(ms1, ms2)
def test_tf_runningmeanstd():
for (x1, x2, x3) in [
(np.random.randn(3), np.random.randn(4), np.random.randn(5)),
(np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),
]:
rms = TfRunningMeanStd(epsilon=0.0, shape=x1.shape[1:], scope='running_mean_std' + str(np.random.randint(0, 128)))
x = np.concatenate([x1, x2, x3], axis=0)
ms1 = [x.mean(axis=0), x.var(axis=0)]
rms.update(x1)
rms.update(x2)
rms.update(x3)
ms2 = [rms.mean, rms.var]
np.testing.assert_allclose(ms1, ms2)
def profile_tf_runningmeanstd():
import time
from baselines.common import tf_util
tf_util.get_session( config=tf.ConfigProto(
inter_op_parallelism_threads=1,
intra_op_parallelism_threads=1,
allow_soft_placement=True
))
x = np.random.random((376,))
n_trials = 10000
rms = RunningMeanStd()
tfrms = TfRunningMeanStd()
tic1 = time.time()
for _ in range(n_trials):
rms.update(x)
tic2 = time.time()
for _ in range(n_trials):
tfrms.update(x)
tic3 = time.time()
print('rms update time ({} trials): {} s'.format(n_trials, tic2 - tic1))
print('tfrms update time ({} trials): {} s'.format(n_trials, tic3 - tic2))
tic1 = time.time()
for _ in range(n_trials):
z1 = rms.mean
tic2 = time.time()
for _ in range(n_trials):
z2 = tfrms.mean
assert z1 == z2
tic3 = time.time()
print('rms get mean time ({} trials): {} s'.format(n_trials, tic2 - tic1))
print('tfrms get mean time ({} trials): {} s'.format(n_trials, tic3 - tic2))
'''
options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) #pylint: disable=E1101
run_metadata = tf.RunMetadata()
profile_opts = dict(options=options, run_metadata=run_metadata)
from tensorflow.python.client import timeline
fetched_timeline = timeline.Timeline(run_metadata.step_stats) #pylint: disable=E1101
chrome_trace = fetched_timeline.generate_chrome_trace_format()
outfile = '/tmp/timeline.json'
with open(outfile, 'wt') as f:
f.write(chrome_trace)
print('Successfully saved profile to {}. Exiting.'.format(outfile))
exit(0)
'''
if __name__ == '__main__':
profile_tf_runningmeanstd()
View File
@@ -1,29 +0,0 @@
from baselines.common import mpi_util
from baselines import logger
from baselines.common.tests.test_with_mpi import with_mpi
try:
from mpi4py import MPI
except ImportError:
MPI = None
@with_mpi()
def test_mpi_weighted_mean():
comm = MPI.COMM_WORLD
with logger.scoped_configure(comm=comm):
if comm.rank == 0:
name2valcount = {'a' : (10, 2), 'b' : (20,3)}
elif comm.rank == 1:
name2valcount = {'a' : (19, 1), 'c' : (42,3)}
else:
raise NotImplementedError
d = mpi_util.mpi_weighted_mean(comm, name2valcount)
correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
if comm.rank == 0:
assert d == correctval, '{} != {}'.format(d, correctval)
for name, (val, count) in name2valcount.items():
for _ in range(count):
logger.logkv_mean(name, val)
d2 = logger.dumpkvs()
if comm.rank == 0:
assert d2 == correctval
View File
@@ -1,2 +0,0 @@
import os, pytest
mark_slow = pytest.mark.skipif(not os.getenv('RUNSLOW'), reason='slow')
View File
@@ -1,43 +0,0 @@
import numpy as np
from gym import Env
from gym.spaces import Discrete
class FixedSequenceEnv(Env):
def __init__(
self,
n_actions=10,
episode_len=100
):
self.action_space = Discrete(n_actions)
self.observation_space = Discrete(1)
self.np_random = np.random.RandomState(0)
self.episode_len = episode_len
self.sequence = [self.np_random.randint(0, self.action_space.n)
for _ in range(self.episode_len)]
self.time = 0
def reset(self):
self.time = 0
return 0
def step(self, actions):
rew = self._get_reward(actions)
self._choose_next_state()
done = False
if self.episode_len and self.time >= self.episode_len:
done = True
return 0, rew, done, {}
def seed(self, seed=None):
self.np_random.seed(seed)
def _choose_next_state(self):
self.time += 1
def _get_reward(self, actions):
return 1 if actions == self.sequence[self.time] else 0
View File
@@ -1,90 +0,0 @@
import numpy as np
from abc import abstractmethod
from gym import Env
from gym.spaces import MultiDiscrete, Discrete, Box
from collections import deque
class IdentityEnv(Env):
def __init__(
self,
episode_len=None,
delay=0,
zero_first_rewards=True
):
self.observation_space = self.action_space
self.episode_len = episode_len
self.time = 0
self.delay = delay
self.zero_first_rewards = zero_first_rewards
self.q = deque(maxlen=delay+1)
def reset(self):
self.q.clear()
for _ in range(self.delay + 1):
self.q.append(self.action_space.sample())
self.time = 0
return self.q[-1]
def step(self, actions):
rew = self._get_reward(self.q.popleft(), actions)
if self.zero_first_rewards and self.time < self.delay:
rew = 0
self.q.append(self.action_space.sample())
self.time += 1
done = self.episode_len is not None and self.time >= self.episode_len
return self.q[-1], rew, done, {}
def seed(self, seed=None):
self.action_space.seed(seed)
@abstractmethod
def _get_reward(self, state, actions):
raise NotImplementedError
class DiscreteIdentityEnv(IdentityEnv):
def __init__(
self,
dim,
episode_len=None,
delay=0,
zero_first_rewards=True
):
self.action_space = Discrete(dim)
super().__init__(episode_len=episode_len, delay=delay, zero_first_rewards=zero_first_rewards)
def _get_reward(self, state, actions):
return 1 if state == actions else 0
class MultiDiscreteIdentityEnv(IdentityEnv):
def __init__(
self,
dims,
episode_len=None,
delay=0,
):
self.action_space = MultiDiscrete(dims)
super().__init__(episode_len=episode_len, delay=delay)
def _get_reward(self, state, actions):
return 1 if all(state == actions) else 0
class BoxIdentityEnv(IdentityEnv):
def __init__(
self,
shape,
episode_len=None,
):
self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32)
super().__init__(episode_len=episode_len)
def _get_reward(self, state, actions):
diff = actions - state
diff = diff[:]
return -0.5 * np.dot(diff, diff)
View File
@@ -1,36 +0,0 @@
from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv
def test_discrete_nodelay():
nsteps = 100
eplen = 50
env = DiscreteIdentityEnv(10, episode_len=eplen)
ob = env.reset()
for t in range(nsteps):
action = env.action_space.sample()
next_ob, rew, done, info = env.step(action)
assert rew == (1 if action == ob else 0)
if (t + 1) % eplen == 0:
assert done
next_ob = env.reset()
else:
assert not done
ob = next_ob
def test_discrete_delay1():
eplen = 50
env = DiscreteIdentityEnv(10, episode_len=eplen, delay=1)
ob = env.reset()
prev_ob = None
for t in range(eplen):
action = env.action_space.sample()
next_ob, rew, done, info = env.step(action)
if t > 0:
assert rew == (1 if action == prev_ob else 0)
else:
assert rew == 0
prev_ob = ob
ob = next_ob
if t < eplen - 1:
assert not done
assert done
View File
@@ -1,71 +0,0 @@
import os.path as osp
import numpy as np
import tempfile
from gym import Env
from gym.spaces import Discrete, Box
class MnistEnv(Env):
def __init__(
self,
episode_len=None,
no_images=None
):
import filelock
from tensorflow.examples.tutorials.mnist import input_data
# we could use temporary directory for this with a context manager and
# TemporaryDirectory, but then each test that uses mnist would re-download the data
# this way the data is not cleaned up, but we only download it once per machine
mnist_path = osp.join(tempfile.gettempdir(), 'MNIST_data')
with filelock.FileLock(mnist_path + '.lock'):
self.mnist = input_data.read_data_sets(mnist_path)
self.np_random = np.random.RandomState()
self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1))
self.action_space = Discrete(10)
self.episode_len = episode_len
self.time = 0
self.no_images = no_images
self.train_mode()
self.reset()
def reset(self):
self._choose_next_state()
self.time = 0
return self.state[0]
def step(self, actions):
rew = self._get_reward(actions)
self._choose_next_state()
done = False
if self.episode_len and self.time >= self.episode_len:
rew = 0
done = True
return self.state[0], rew, done, {}
def seed(self, seed=None):
self.np_random.seed(seed)
def train_mode(self):
self.dataset = self.mnist.train
def test_mode(self):
self.dataset = self.mnist.test
def _choose_next_state(self):
max_index = (self.no_images if self.no_images is not None else self.dataset.num_examples) - 1
index = self.np_random.randint(0, max_index)
image = self.dataset.images[index].reshape(28,28,1)*255
label = self.dataset.labels[index]
self.state = (image, label)
self.time += 1
def _get_reward(self, actions):
return 1 if self.state[1] == actions else 0
View File
@@ -1,45 +0,0 @@
import pytest
import gym
from baselines.run import get_learn_function
from baselines.common.tests.util import reward_per_episode_test
from baselines.common.tests import mark_slow
common_kwargs = dict(
total_timesteps=30000,
network='mlp',
gamma=1.0,
seed=0,
)
learn_kwargs = {
'a2c' : dict(nsteps=32, value_network='copy', lr=0.05),
'acer': dict(value_network='copy'),
'acktr': dict(nsteps=32, value_network='copy', is_async=False),
'deepq': dict(total_timesteps=20000),
'ppo2': dict(value_network='copy'),
'trpo_mpi': {}
}
@mark_slow
@pytest.mark.parametrize("alg", learn_kwargs.keys())
def test_cartpole(alg):
'''
Test if the algorithm (with an mlp policy)
can learn to balance the cartpole
'''
kwargs = common_kwargs.copy()
kwargs.update(learn_kwargs[alg])
learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
def env_fn():
env = gym.make('CartPole-v0')
env.seed(0)
return env
reward_per_episode_test(env_fn, learn_fn, 100)
if __name__ == '__main__':
test_cartpole('acer')
View File
@@ -1,48 +0,0 @@
import pytest
try:
import mujoco_py
_mujoco_present = True
except BaseException:
mujoco_py = None
_mujoco_present = False
@pytest.mark.skipif(
not _mujoco_present,
reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library'
)
def test_lstm_example():
import tensorflow as tf
from baselines.common import policies, models, cmd_util
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
# create vectorized environment
venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)])
with tf.Session() as sess:
# build policy based on lstm network with 128 units
policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1)
# initialize tensorflow variables
sess.run(tf.global_variables_initializer())
# prepare environment variables
ob = venv.reset()
state = policy.initial_state
done = [False]
step_counter = 0
# run a single episode until the end (i.e. until done)
while True:
action, _, state, _ = policy.step(ob, S=state, M=done)
ob, reward, done, _ = venv.step(action)
step_counter += 1
if done:
break
assert step_counter > 5

View File

@@ -1,27 +0,0 @@
import pytest
import gym
import tensorflow as tf
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.run import get_learn_function
from baselines.common.tf_util import make_session
algos = ['a2c', 'acer', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
@pytest.mark.parametrize('algo', algos)
def test_env_after_learn(algo):
def make_env():
# acktr requires too much RAM, fails on travis
env = gym.make('CartPole-v1' if algo == 'acktr' else 'PongNoFrameskip-v4')
return env
make_session(make_default=True, graph=tf.Graph())
env = SubprocVecEnv([make_env])
learn = get_learn_function(algo)
# Commenting out the following line resolves the issue, though crash happens at env.reset().
learn(network='mlp', env=env, total_timesteps=0, load_path=None, seed=None)
env.reset()
env.close()
View File
@@ -1,40 +0,0 @@
import pytest
import gym
from baselines.run import get_learn_function
from baselines.common.tests.util import reward_per_episode_test
from baselines.common.tests import mark_slow
pytest.importorskip('mujoco_py')
common_kwargs = dict(
network='mlp',
seed=0,
)
learn_kwargs = {
'her': dict(total_timesteps=2000)
}
@mark_slow
@pytest.mark.parametrize("alg", learn_kwargs.keys())
def test_fetchreach(alg):
'''
Test if the algorithm (with an mlp policy)
can learn the FetchReach task
'''
kwargs = common_kwargs.copy()
kwargs.update(learn_kwargs[alg])
learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
def env_fn():
env = gym.make('FetchReach-v1')
env.seed(0)
return env
reward_per_episode_test(env_fn, learn_fn, -15)
if __name__ == '__main__':
test_fetchreach('her')
View File
@@ -1,52 +0,0 @@
import pytest
from baselines.common.tests.envs.fixed_sequence_env import FixedSequenceEnv
from baselines.common.tests.util import simple_test
from baselines.run import get_learn_function
from baselines.common.tests import mark_slow
common_kwargs = dict(
seed=0,
total_timesteps=50000,
)
learn_kwargs = {
'a2c': {},
'ppo2': dict(nsteps=10, ent_coef=0.0, nminibatches=1),
# TODO enable sequential models for trpo_mpi (proper handling of nbatch and nsteps)
# github issue: https://github.com/openai/baselines/issues/188
# 'trpo_mpi': lambda e, p: trpo_mpi.learn(policy_fn=p(env=e), env=e, max_timesteps=30000, timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.001)
}
alg_list = learn_kwargs.keys()
rnn_list = ['lstm']
@mark_slow
@pytest.mark.parametrize("alg", alg_list)
@pytest.mark.parametrize("rnn", rnn_list)
def test_fixed_sequence(alg, rnn):
'''
Test if the algorithm (with a given policy)
can learn an identity transformation (i.e. return observation as an action)
'''
kwargs = learn_kwargs[alg]
kwargs.update(common_kwargs)
env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
learn = lambda e: get_learn_function(alg)(
env=e,
network=rnn,
**kwargs
)
simple_test(env_fn, learn, 0.7)
if __name__ == '__main__':
test_fixed_sequence('ppo2', 'lstm')
View File
@@ -1,76 +0,0 @@
import pytest
from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv, MultiDiscreteIdentityEnv
from baselines.run import get_learn_function
from baselines.common.tests.util import simple_test
from baselines.common.tests import mark_slow
common_kwargs = dict(
total_timesteps=30000,
network='mlp',
gamma=0.9,
seed=0,
)
learn_kwargs = {
'a2c' : {},
'acktr': {},
'deepq': {},
'ddpg': dict(layer_norm=True),
'ppo2': dict(lr=1e-3, nsteps=64, ent_coef=0.0),
'trpo_mpi': dict(timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.01)
}
algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
algos_multidisc = ['a2c', 'acktr', 'ppo2', 'trpo_mpi']
algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi']
@mark_slow
@pytest.mark.parametrize("alg", algos_disc)
def test_discrete_identity(alg):
'''
Test if the algorithm (with an mlp policy)
can learn an identity transformation (i.e. return observation as an action)
'''
kwargs = learn_kwargs[alg]
kwargs.update(common_kwargs)
learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
env_fn = lambda: DiscreteIdentityEnv(10, episode_len=100)
simple_test(env_fn, learn_fn, 0.9)
@mark_slow
@pytest.mark.parametrize("alg", algos_multidisc)
def test_multidiscrete_identity(alg):
'''
Test if the algorithm (with an mlp policy)
can learn an identity transformation (i.e. return observation as an action)
'''
kwargs = learn_kwargs[alg]
kwargs.update(common_kwargs)
learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
env_fn = lambda: MultiDiscreteIdentityEnv((3,3), episode_len=100)
simple_test(env_fn, learn_fn, 0.9)
@mark_slow
@pytest.mark.parametrize("alg", algos_cont)
def test_continuous_identity(alg):
'''
Test if the algorithm (with an mlp policy)
can learn an identity transformation (i.e. return observation as an action)
to a required precision
'''
kwargs = learn_kwargs[alg]
kwargs.update(common_kwargs)
learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
env_fn = lambda: BoxIdentityEnv((1,), episode_len=100)
simple_test(env_fn, learn_fn, -0.1)
if __name__ == '__main__':
test_multidiscrete_identity('acktr')
View File
@@ -1,49 +0,0 @@
import pytest
# from baselines.acer import acer_simple as acer
from baselines.common.tests.envs.mnist_env import MnistEnv
from baselines.common.tests.util import simple_test
from baselines.run import get_learn_function
from baselines.common.tests import mark_slow
# TODO investigate a2c and ppo2 failures - is it due to bad hyperparameters for this problem?
# GitHub issue https://github.com/openai/baselines/issues/189
common_kwargs = {
'seed': 0,
'network':'cnn',
'gamma':0.9,
'pad':'SAME'
}
learn_args = {
'a2c': dict(total_timesteps=50000),
'acer': dict(total_timesteps=20000),
'deepq': dict(total_timesteps=5000),
'acktr': dict(total_timesteps=30000),
'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0),
'trpo_mpi': dict(total_timesteps=80000, timesteps_per_batch=100, cg_iters=10, lam=1.0, max_kl=0.001)
}
#tests pass, but are too slow on travis. Same algorithms are covered
# by other tests with less compute-hungry nn's and by benchmarks
@pytest.mark.skip
@mark_slow
@pytest.mark.parametrize("alg", learn_args.keys())
def test_mnist(alg):
'''
Test if the algorithm can learn to classify MNIST digits.
Uses CNN policy.
'''
learn_kwargs = learn_args[alg]
learn_kwargs.update(common_kwargs)
learn = get_learn_function(alg)
learn_fn = lambda e: learn(env=e, **learn_kwargs)
env_fn = lambda: MnistEnv(episode_len=100)
simple_test(env_fn, learn_fn, 0.6)
if __name__ == '__main__':
test_mnist('acer')
View File
@@ -1,17 +0,0 @@
# smoke tests of plot_util
from baselines.common import plot_util as pu
from baselines.common.tests.util import smoketest
def test_plot_util():
nruns = 4
logdirs = [smoketest('--alg=ppo2 --env=CartPole-v0 --num_timesteps=10000') for _ in range(nruns)]
data = pu.load_results(logdirs)
assert len(data) == 4
_, axes = pu.plot_results(data[:1]); assert len(axes) == 1
_, axes = pu.plot_results(data, tiling='vertical'); assert axes.shape==(4,1)
_, axes = pu.plot_results(data, tiling='horizontal'); assert axes.shape==(1,4)
_, axes = pu.plot_results(data, tiling='symmetric'); assert axes.shape==(2,2)
_, axes = pu.plot_results(data, split_fn=lambda _: ''); assert len(axes) == 1
View File
@@ -1,26 +0,0 @@
import numpy as np
from baselines.common.schedules import ConstantSchedule, PiecewiseSchedule
def test_piecewise_schedule():
ps = PiecewiseSchedule([(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)], outside_value=500)
assert np.isclose(ps.value(-10), 500)
assert np.isclose(ps.value(0), 150)
assert np.isclose(ps.value(5), 200)
assert np.isclose(ps.value(9), 80)
assert np.isclose(ps.value(50), 50)
assert np.isclose(ps.value(80), 50)
assert np.isclose(ps.value(150), 0)
assert np.isclose(ps.value(175), -25)
assert np.isclose(ps.value(201), 500)
assert np.isclose(ps.value(500), 500)
assert np.isclose(ps.value(200 - 1e-10), -50)
def test_constant_schedule():
cs = ConstantSchedule(5)
for i in range(-100, 100):
assert np.isclose(cs.value(i), 5)
View File
@@ -1,103 +0,0 @@
import numpy as np
from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree
def test_tree_set():
tree = SumSegmentTree(4)
tree[2] = 1.0
tree[3] = 3.0
assert np.isclose(tree.sum(), 4.0)
assert np.isclose(tree.sum(0, 2), 0.0)
assert np.isclose(tree.sum(0, 3), 1.0)
assert np.isclose(tree.sum(2, 3), 1.0)
assert np.isclose(tree.sum(2, -1), 1.0)
assert np.isclose(tree.sum(2, 4), 4.0)
def test_tree_set_overlap():
tree = SumSegmentTree(4)
tree[2] = 1.0
tree[2] = 3.0
assert np.isclose(tree.sum(), 3.0)
assert np.isclose(tree.sum(2, 3), 3.0)
assert np.isclose(tree.sum(2, -1), 3.0)
assert np.isclose(tree.sum(2, 4), 3.0)
assert np.isclose(tree.sum(1, 2), 0.0)
def test_prefixsum_idx():
tree = SumSegmentTree(4)
tree[2] = 1.0
tree[3] = 3.0
assert tree.find_prefixsum_idx(0.0) == 2
assert tree.find_prefixsum_idx(0.5) == 2
assert tree.find_prefixsum_idx(0.99) == 2
assert tree.find_prefixsum_idx(1.01) == 3
assert tree.find_prefixsum_idx(3.00) == 3
assert tree.find_prefixsum_idx(4.00) == 3
def test_prefixsum_idx2():
tree = SumSegmentTree(4)
tree[0] = 0.5
tree[1] = 1.0
tree[2] = 1.0
tree[3] = 3.0
assert tree.find_prefixsum_idx(0.00) == 0
assert tree.find_prefixsum_idx(0.55) == 1
assert tree.find_prefixsum_idx(0.99) == 1
assert tree.find_prefixsum_idx(1.51) == 2
assert tree.find_prefixsum_idx(3.00) == 3
assert tree.find_prefixsum_idx(5.50) == 3
def test_max_interval_tree():
tree = MinSegmentTree(4)
tree[0] = 1.0
tree[2] = 0.5
tree[3] = 3.0
assert np.isclose(tree.min(), 0.5)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 0.5)
assert np.isclose(tree.min(0, -1), 0.5)
assert np.isclose(tree.min(2, 4), 0.5)
assert np.isclose(tree.min(3, 4), 3.0)
tree[2] = 0.7
assert np.isclose(tree.min(), 0.7)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 0.7)
assert np.isclose(tree.min(0, -1), 0.7)
assert np.isclose(tree.min(2, 4), 0.7)
assert np.isclose(tree.min(3, 4), 3.0)
tree[2] = 4.0
assert np.isclose(tree.min(), 1.0)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 1.0)
assert np.isclose(tree.min(0, -1), 1.0)
assert np.isclose(tree.min(2, 4), 3.0)
assert np.isclose(tree.min(2, 3), 4.0)
assert np.isclose(tree.min(2, -1), 4.0)
assert np.isclose(tree.min(3, 4), 3.0)
if __name__ == '__main__':
test_tree_set()
test_tree_set_overlap()
test_prefixsum_idx()
test_prefixsum_idx2()
test_max_interval_tree()
View File
@@ -1,139 +0,0 @@
import os
import gym
import tempfile
import pytest
import tensorflow as tf
import numpy as np
from baselines.common.tests.envs.mnist_env import MnistEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.run import get_learn_function
from baselines.common.tf_util import make_session, get_session
from functools import partial
learn_kwargs = {
'deepq': {},
'a2c': {},
'acktr': {},
'acer': {},
'ppo2': {'nminibatches': 1, 'nsteps': 10},
'trpo_mpi': {},
}
network_kwargs = {
'mlp': {},
'cnn': {'pad': 'SAME'},
'lstm': {},
'cnn_lnlstm': {'pad': 'SAME'}
}
@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
@pytest.mark.parametrize("network_fn", network_kwargs.keys())
def test_serialization(learn_fn, network_fn):
'''
Test if the trained model can be serialized
'''
if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']:
# TODO make acktr work with recurrent policies
# and test
# github issue: https://github.com/openai/baselines/issues/660
return
def make_env():
env = MnistEnv(episode_len=100)
env.seed(10)
return env
env = DummyVecEnv([make_env])
ob = env.reset().copy()
learn = get_learn_function(learn_fn)
kwargs = {}
kwargs.update(network_kwargs[network_fn])
kwargs.update(learn_kwargs[learn_fn])
learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)
with tempfile.TemporaryDirectory() as td:
model_path = os.path.join(td, 'serialization_test_model')
with tf.Graph().as_default(), make_session().as_default():
model = learn(total_timesteps=100)
model.save(model_path)
mean1, std1 = _get_action_stats(model, ob)
variables_dict1 = _serialize_variables()
with tf.Graph().as_default(), make_session().as_default():
model = learn(total_timesteps=0, load_path=model_path)
mean2, std2 = _get_action_stats(model, ob)
variables_dict2 = _serialize_variables()
for k, v in variables_dict1.items():
np.testing.assert_allclose(v, variables_dict2[k], atol=0.01,
err_msg='saved and loaded variable {} value mismatch'.format(k))
np.testing.assert_allclose(mean1, mean2, atol=0.5)
np.testing.assert_allclose(std1, std2, atol=0.5)
@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
@pytest.mark.parametrize("network_fn", ['mlp'])
def test_coexistence(learn_fn, network_fn):
'''
Test if more than one model can exist at a time
'''
if learn_fn == 'deepq':
# TODO enable multiple DQN models to be useable at the same time
# github issue https://github.com/openai/baselines/issues/656
return
if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
# TODO make acktr work with recurrent policies
# and test
# github issue: https://github.com/openai/baselines/issues/660
return
env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
learn = get_learn_function(learn_fn)
kwargs = {}
kwargs.update(network_kwargs[network_fn])
kwargs.update(learn_kwargs[learn_fn])
learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
make_session(make_default=True, graph=tf.Graph())
model1 = learn(seed=1)
make_session(make_default=True, graph=tf.Graph())
model2 = learn(seed=2)
model1.step(env.observation_space.sample())
model2.step(env.observation_space.sample())
def _serialize_variables():
sess = get_session()
variables = tf.trainable_variables()
values = sess.run(variables)
return {var.name: value for var, value in zip(variables, values)}
def _get_action_stats(model, ob):
ntrials = 1000
if model.initial_state is None or model.initial_state == []:
actions = np.array([model.step(ob)[0] for _ in range(ntrials)])
else:
actions = np.array([model.step(ob, S=model.initial_state, M=[False])[0] for _ in range(ntrials)])
mean = np.mean(actions, axis=0)
std = np.std(actions, axis=0)
return mean, std
View File
@@ -1,42 +0,0 @@
# tests for tf_util
import tensorflow as tf
from baselines.common.tf_util import (
function,
initialize,
single_threaded_session
)
def test_function():
with tf.Graph().as_default():
x = tf.placeholder(tf.int32, (), name="x")
y = tf.placeholder(tf.int32, (), name="y")
z = 3 * x + 2 * y
lin = function([x, y], z, givens={y: 0})
with single_threaded_session():
initialize()
assert lin(2) == 6
assert lin(x=3) == 9
assert lin(2, 2) == 10
assert lin(x=2, y=3) == 12
def test_multikwargs():
with tf.Graph().as_default():
x = tf.placeholder(tf.int32, (), name="x")
with tf.variable_scope("other"):
x2 = tf.placeholder(tf.int32, (), name="x")
z = 3 * x + 2 * x2
lin = function([x, x2], z, givens={x2: 0})
with single_threaded_session():
initialize()
assert lin(2) == 6
assert lin(2, 2) == 10
if __name__ == '__main__':
test_function()
View File

@@ -1,92 +0,0 @@
import tensorflow as tf
import numpy as np
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
N_TRIALS = 10000
N_EPISODES = 100
_sess_config = tf.ConfigProto(
allow_soft_placement=True,
intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1
)
def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
def seeded_env_fn():
env = env_fn()
env.seed(0)
return env
np.random.seed(0)
env = DummyVecEnv([seeded_env_fn])
with tf.Graph().as_default(), tf.Session(config=_sess_config).as_default():
tf.set_random_seed(0)
model = learn_fn(env)
sum_rew = 0
done = True
for i in range(n_trials):
if done:
obs = env.reset()
state = model.initial_state
if state is not None:
a, v, state, _ = model.step(obs, S=state, M=[False])
else:
a, v, _, _ = model.step(obs)
obs, rew, done, _ = env.step(a)
sum_rew += float(rew)
print("Reward in {} trials is {}".format(n_trials, sum_rew))
assert sum_rew > min_reward_fraction * n_trials, \
'sum of rewards {} is less than {} of the total number of trials {}'.format(sum_rew, min_reward_fraction, n_trials)
def reward_per_episode_test(env_fn, learn_fn, min_avg_reward, n_trials=N_EPISODES):
env = DummyVecEnv([env_fn])
with tf.Graph().as_default(), tf.Session(config=_sess_config).as_default():
model = learn_fn(env)
N_TRIALS = 100
observations, actions, rewards = rollout(env, model, N_TRIALS)
rewards = [sum(r) for r in rewards]
avg_rew = sum(rewards) / N_TRIALS
print("Average reward in {} episodes is {}".format(n_trials, avg_rew))
assert avg_rew > min_avg_reward, \
'average reward in {} episodes ({}) is less than {}'.format(n_trials, avg_rew, min_avg_reward)
def rollout(env, model, n_trials):
rewards = []
actions = []
observations = []
for i in range(n_trials):
obs = env.reset()
state = model.initial_state if hasattr(model, 'initial_state') else None
episode_rew = []
episode_actions = []
episode_obs = []
while True:
if state is not None:
a, v, state, _ = model.step(obs, S=state, M=[False])
else:
a,v, _, _ = model.step(obs)
obs, rew, done, _ = env.step(a)
episode_rew.append(rew)
episode_actions.append(a)
episode_obs.append(obs)
if done:
break
rewards.append(episode_rew)
actions.append(episode_actions)
observations.append(episode_obs)
return observations, actions, rewards
def smoketest(argstr, **kwargs):
import tempfile
import subprocess
import os
argstr = 'python -m baselines.run ' + argstr
for key, value in kwargs:
argstr += ' --{}={}'.format(key, value)
tempdir = tempfile.mkdtemp()
env = os.environ.copy()
env['OPENAI_LOGDIR'] = tempdir
subprocess.run(argstr.split(' '), env=env)
return tempdir
View File
@@ -1,10 +1,6 @@
import numpy as np
import tensorflow as tf # pylint: ignore-module
import copy
import os
import functools
import collections
import multiprocessing
def switch(condition, then_expression, else_expression):
"""Switches between two operations depending on a scalar value (int or bool).
@@ -44,52 +40,6 @@ def huber_loss(x, delta=1.0):
delta * (tf.abs(x) - 0.5 * delta)
)
# ================================================================
# Global session
# ================================================================
def get_session(config=None):
"""Get default session or create one with a given config"""
sess = tf.get_default_session()
if sess is None:
sess = make_session(config=config, make_default=True)
return sess
def make_session(config=None, num_cpu=None, make_default=False, graph=None):
"""Returns a session that will use <num_cpu> CPU's only"""
if num_cpu is None:
num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count()))
if config is None:
config = tf.ConfigProto(
allow_soft_placement=True,
inter_op_parallelism_threads=num_cpu,
intra_op_parallelism_threads=num_cpu)
config.gpu_options.allow_growth = True
if make_default:
return tf.InteractiveSession(config=config, graph=graph)
else:
return tf.Session(config=config, graph=graph)
def single_threaded_session():
"""Returns a session which will only use a single CPU"""
return make_session(num_cpu=1)
def in_session(f):
@functools.wraps(f)
def newfunc(*args, **kwargs):
with tf.Session():
f(*args, **kwargs)
return newfunc
ALREADY_INITIALIZED = set()
def initialize():
"""Initialize all the uninitialized variables in the global scope."""
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
get_session().run(tf.variables_initializer(new_variables))
ALREADY_INITIALIZED.update(new_variables)
# ================================================================
# Model components
# ================================================================
@@ -130,87 +80,6 @@ def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME",
return tf.nn.conv2d(x, w, stride_shape, pad) + b
# ================================================================
# Theano-like Function
# ================================================================
def function(inputs, outputs, updates=None, givens=None):
"""Just like Theano function. Take a bunch of tensorflow placeholders and expressions
computed based on those placeholders and produces f(inputs) -> outputs. Function f takes
values to be fed to the input's placeholders and produces the values of the expressions
in outputs.
Input values can be passed in the same order as inputs or can be provided as kwargs based
on placeholder name (passed to constructor or accessible via placeholder.op.name).
Example:
x = tf.placeholder(tf.int32, (), name="x")
y = tf.placeholder(tf.int32, (), name="y")
z = 3 * x + 2 * y
lin = function([x, y], z, givens={y: 0})
with single_threaded_session():
initialize()
assert lin(2) == 6
assert lin(x=3) == 9
assert lin(2, 2) == 10
assert lin(x=2, y=3) == 12
Parameters
----------
inputs: [tf.placeholder, tf.constant, or object with make_feed_dict method]
list of input arguments
outputs: [tf.Variable] or tf.Variable
list of outputs or a single output to be returned from function. Returned
value will also have the same shape.
updates: [tf.Operation] or tf.Operation
list of update functions or single update function that will be run whenever
the function is called. The return is ignored.
"""
if isinstance(outputs, list):
return _Function(inputs, outputs, updates, givens=givens)
elif isinstance(outputs, (dict, collections.OrderedDict)):
f = _Function(inputs, outputs.values(), updates, givens=givens)
return lambda *args, **kwargs: type(outputs)(zip(outputs.keys(), f(*args, **kwargs)))
else:
f = _Function(inputs, [outputs], updates, givens=givens)
return lambda *args, **kwargs: f(*args, **kwargs)[0]
class _Function(object):
def __init__(self, inputs, outputs, updates, givens):
for inpt in inputs:
if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0):
assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method"
self.inputs = inputs
self.input_names = {inp.name.split("/")[-1].split(":")[0]: inp for inp in inputs}
updates = updates or []
self.update_group = tf.group(*updates)
self.outputs_update = list(outputs) + [self.update_group]
self.givens = {} if givens is None else givens
def _feed_input(self, feed_dict, inpt, value):
if hasattr(inpt, 'make_feed_dict'):
feed_dict.update(inpt.make_feed_dict(value))
else:
feed_dict[inpt] = adjust_shape(inpt, value)
def __call__(self, *args, **kwargs):
assert len(args) + len(kwargs) <= len(self.inputs), "Too many arguments provided"
feed_dict = {}
# Update feed dict with givens.
for inpt in self.givens:
feed_dict[inpt] = adjust_shape(inpt, feed_dict.get(inpt, self.givens[inpt]))
# Update the args
for inpt, value in zip(self.inputs, args):
self._feed_input(feed_dict, inpt, value)
for inpt_name, value in kwargs.items():
self._feed_input(feed_dict, self.input_names[inpt_name], value)
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
return results
# ================================================================
# Flat vectors
# ================================================================
@@ -227,8 +96,7 @@ def numel(x):
def intprod(x):
return int(np.prod(x))
def flatgrad(loss, var_list, clip_norm=None):
grads = tf.gradients(loss, var_list)
def flatgrad(grads, var_list, clip_norm=None):
if clip_norm is not None:
grads = [tf.clip_by_norm(grad, clip_norm=clip_norm) for grad in grads]
return tf.concat(axis=0, values=[
@@ -238,151 +106,40 @@ def flatgrad(loss, var_list, clip_norm=None):
class SetFromFlat(object):
def __init__(self, var_list, dtype=tf.float32):
assigns = []
shapes = list(map(var_shape, var_list))
total_size = np.sum([intprod(shape) for shape in shapes])
self.theta = theta = tf.placeholder(dtype, [total_size])
start = 0
assigns = []
for (shape, v) in zip(shapes, var_list):
size = intprod(shape)
assigns.append(tf.assign(v, tf.reshape(theta[start:start + size], shape)))
start += size
self.op = tf.group(*assigns)
self.shapes = list(map(var_shape, var_list))
self.total_size = np.sum([intprod(shape) for shape in self.shapes])
self.var_list = var_list
def __call__(self, theta):
tf.get_default_session().run(self.op, feed_dict={self.theta: theta})
start = 0
for (shape, v) in zip(self.shapes, self.var_list):
size = intprod(shape)
v.assign(tf.reshape(theta[start:start + size], shape))
start += size
class GetFlat(object):
def __init__(self, var_list):
self.op = tf.concat(axis=0, values=[tf.reshape(v, [numel(v)]) for v in var_list])
self.var_list = var_list
def __call__(self):
return tf.get_default_session().run(self.op)
return tf.concat(axis=0, values=[tf.reshape(v, [numel(v)]) for v in self.var_list]).numpy()
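# Usage sketch (not part of this commit): the eager GetFlat/SetFromFlat pair above reads and
# writes parameters as one flat numpy vector, as the trpo_mpi code does; the two toy
# variables are illustrative only.
import numpy as np
import tensorflow as tf
from baselines.common.tf_util import GetFlat, SetFromFlat

a = tf.Variable(np.zeros((2, 3), dtype=np.float32))
b = tf.Variable(np.ones(5, dtype=np.float32))
get_flat = GetFlat([a, b])
set_from_flat = SetFromFlat([a, b])

theta = get_flat()             # flat numpy vector of length 11
set_from_flat(theta + 1.0)     # scatter the perturbed vector back into a and b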
def flattenallbut0(x):
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
# =============================================================
# TF placeholders management
# ============================================================
_PLACEHOLDER_CACHE = {} # name -> (placeholder, dtype, shape)
def get_placeholder(name, dtype, shape):
if name in _PLACEHOLDER_CACHE:
out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
if out.graph == tf.get_default_graph():
assert dtype1 == dtype and shape1 == shape, \
'Placeholder with name {} has already been registered and has shape {}, different from requested {}'.format(name, shape1, shape)
return out
out = tf.placeholder(dtype=dtype, shape=shape, name=name)
_PLACEHOLDER_CACHE[name] = (out, dtype, shape)
return out
def get_placeholder_cached(name):
return _PLACEHOLDER_CACHE[name][0]
# ================================================================
# Diagnostics
# Shape adjustment for feeding into tf tensors
# ================================================================
def display_var_info(vars):
from baselines import logger
count_params = 0
for v in vars:
name = v.name
if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue
v_params = np.prod(v.shape.as_list())
count_params += v_params
if "/b:" in name or "/bias" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print
logger.info(" %s%s %i params %s" % (name, " "*(55-len(name)), v_params, str(v.shape)))
logger.info("Total model parameters: %0.2f million" % (count_params*1e-6))
def get_available_gpus(session_config=None):
# based on recipe from https://stackoverflow.com/a/38580201
# Unless we allocate a session here, subsequent attempts to create one
# will ignore our custom config (in particular, allow_growth=True will have
# no effect).
if session_config is None:
session_config = get_session()._config
from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices(session_config)
return [x.name for x in local_device_protos if x.device_type == 'GPU']
# ================================================================
# Saving variables
# ================================================================
def load_state(fname, sess=None):
from baselines import logger
logger.warn('load_state method is deprecated, please use load_variables instead')
sess = sess or get_session()
saver = tf.train.Saver()
saver.restore(tf.get_default_session(), fname)
def save_state(fname, sess=None):
from baselines import logger
logger.warn('save_state method is deprecated, please use save_variables instead')
sess = sess or get_session()
dirname = os.path.dirname(fname)
if any(dirname):
os.makedirs(dirname, exist_ok=True)
saver = tf.train.Saver()
saver.save(tf.get_default_session(), fname)
# The methods above and below are clearly doing the same thing, and in a rather similar way
# TODO: ensure there are no subtle differences and remove one
def save_variables(save_path, variables=None, sess=None):
import joblib
sess = sess or get_session()
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
ps = sess.run(variables)
save_dict = {v.name: value for v, value in zip(variables, ps)}
dirname = os.path.dirname(save_path)
if any(dirname):
os.makedirs(dirname, exist_ok=True)
joblib.dump(save_dict, save_path)
def load_variables(load_path, variables=None, sess=None):
import joblib
sess = sess or get_session()
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
loaded_params = joblib.load(os.path.expanduser(load_path))
restores = []
if isinstance(loaded_params, list):
assert len(loaded_params) == len(variables), 'number of variables loaded mismatches len(variables)'
for d, v in zip(loaded_params, variables):
restores.append(v.assign(d))
else:
for v in variables:
restores.append(v.assign(loaded_params[v.name]))
sess.run(restores)
# ================================================================
# Shape adjustment for feeding into tf placeholders
# ================================================================
def adjust_shape(placeholder, data):
def adjust_shape(input_tensor, data):
'''
adjust shape of the data to the shape of the placeholder if possible.
adjust shape of the data to the shape of the tensor if possible.
If shape is incompatible, AssertionError is thrown
Parameters:
placeholder tensorflow input placeholder
input_tensor tensorflow input tensor
data input data to be (potentially) reshaped to be fed into placeholder
data input data to be (potentially) reshaped to be fed into the input tensor
Returns:
reshaped data
@@ -393,24 +150,23 @@ def adjust_shape(placeholder, data):
if isinstance(data, list):
data = np.array(data)
placeholder_shape = [x or -1 for x in placeholder.shape.as_list()]
input_shape = [x or -1 for x in input_tensor.shape.as_list()]
assert _check_shape(placeholder_shape, data.shape), \
'Shape of data {} is not compatible with shape of the placeholder {}'.format(data.shape, placeholder_shape)
assert _check_shape(input_shape, data.shape), \
'Shape of data {} is not compatible with shape of the input {}'.format(data.shape, input_shape)
return np.reshape(data, placeholder_shape)
return np.reshape(data, input_shape)
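# A usage sketch under assumed shapes: adjust_shape lets a single unbatched observation
# be fed where a batched input of shape (None, 4) is expected.
#
#   x = tf.keras.Input(shape=(4,))          # input tensor, shape (None, 4)
#   obs = [0.1, 0.2, 0.3, 0.4]
#   batched = adjust_shape(x, obs)          # -> np.ndarray of shape (1, 4)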
def _check_shape(placeholder_shape, data_shape):
def _check_shape(input_shape, data_shape):
''' check if two shapes are compatible (i.e. differ only by dimensions of size 1, or by the batch dimension)'''
return True
squeezed_placeholder_shape = _squeeze_shape(placeholder_shape)
squeezed_input_shape = _squeeze_shape(input_shape)
squeezed_data_shape = _squeeze_shape(data_shape)
for i, s_data in enumerate(squeezed_data_shape):
s_placeholder = squeezed_placeholder_shape[i]
if s_placeholder != -1 and s_data != s_placeholder:
s_input = squeezed_input_shape[i]
if s_input != -1 and s_data != s_input:
return False
return True

View File

@@ -70,11 +70,9 @@ class ShmemVecEnv(VecEnv):
assert len(actions) == len(self.parent_pipes)
for pipe, act in zip(self.parent_pipes, actions):
pipe.send(('step', act))
self.waiting_step = True
def step_wait(self):
outs = [pipe.recv() for pipe in self.parent_pipes]
self.waiting_step = False
obs, rews, dones, infos = zip(*outs)
return self._decode_obses(obs), np.array(rews), np.array(dones), infos

View File

@@ -38,9 +38,6 @@ def obs_space_info(obs_space):
if isinstance(obs_space, gym.spaces.Dict):
assert isinstance(obs_space.spaces, OrderedDict)
subspaces = obs_space.spaces
elif isinstance(obs_space, gym.spaces.Tuple):
assert isinstance(obs_space.spaces, tuple)
subspaces = {i: obs_space.spaces[i] for i in range(len(obs_space.spaces))}
else:
subspaces = {None: obs_space}
keys = []

View File

@@ -146,8 +146,8 @@ class VecEnvWrapper(VecEnv):
def __init__(self, venv, observation_space=None, action_space=None):
self.venv = venv
super().__init__(num_envs=venv.num_envs,
observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space)
observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space)
def step_async(self, actions):
self.venv.step_async(actions)
@@ -169,11 +169,6 @@ class VecEnvWrapper(VecEnv):
def get_images(self):
return self.venv.get_images()
def __getattr__(self, name):
if name.startswith('_'):
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
return getattr(self.venv, name)
class VecEnvObservationWrapper(VecEnvWrapper):
@abstractmethod
def process(self, obs):

View File

@@ -5,18 +5,16 @@ import time
from collections import deque
class VecMonitor(VecEnvWrapper):
def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
def __init__(self, venv, filename=None, keep_buf=0):
VecEnvWrapper.__init__(self, venv)
self.eprets = None
self.eplens = None
self.epcount = 0
self.tstart = time.time()
if filename:
self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart},
extra_keys=info_keywords)
self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
else:
self.results_writer = None
self.info_keywords = info_keywords
self.keep_buf = keep_buf
if self.keep_buf:
self.epret_buf = deque([], maxlen=keep_buf)
@@ -32,16 +30,11 @@ class VecMonitor(VecEnvWrapper):
obs, rews, dones, infos = self.venv.step_wait()
self.eprets += rews
self.eplens += 1
newinfos = list(infos[:])
for i in range(len(dones)):
if dones[i]:
info = infos[i].copy()
ret = self.eprets[i]
eplen = self.eplens[i]
newinfos = []
for (i, (done, ret, eplen, info)) in enumerate(zip(dones, self.eprets, self.eplens, infos)):
info = info.copy()
if done:
epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
for k in self.info_keywords:
epinfo[k] = info[k]
info['episode'] = epinfo
if self.keep_buf:
self.epret_buf.append(ret)
@@ -51,5 +44,6 @@ class VecMonitor(VecEnvWrapper):
self.eplens[i] = 0
if self.results_writer:
self.results_writer.write_row(epinfo)
newinfos[i] = info
newinfos.append(info)
return obs, rews, dones, newinfos
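# A consumption sketch (assumed training-loop code): finished episodes are reported
# through info['episode'], so per-episode returns can be aggregated like this.
#
#   obs, rews, dones, infos = venv.step(actions)
#   episode_returns = [info['episode']['r'] for info in infos if 'episode' in info]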

View File

@@ -1,22 +1,18 @@
from . import VecEnvWrapper
from baselines.common.running_mean_std import RunningMeanStd
import numpy as np
class VecNormalize(VecEnvWrapper):
"""
A vectorized wrapper that normalizes the observations
and returns from an environment.
"""
def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8, use_tf=False):
def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
VecEnvWrapper.__init__(self, venv)
if use_tf:
from baselines.common.running_mean_std import TfRunningMeanStd
self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='ob_rms') if ob else None
self.ret_rms = TfRunningMeanStd(shape=(), scope='ret_rms') if ret else None
else:
from baselines.common.running_mean_std import RunningMeanStd
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
self.ret_rms = RunningMeanStd(shape=()) if ret else None
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
self.ret_rms = RunningMeanStd(shape=()) if ret else None
self.clipob = clipob
self.cliprew = cliprew
self.ret = np.zeros(self.num_envs)
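# A sketch of how these statistics are applied (the step_wait/update code sits outside
# this hunk; the form below is the usual running-mean-std normalization):
#
#   obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon),
#                 -self.clipob, self.clipob)
#   rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon),
#                  -self.cliprew, self.cliprew)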

View File

@@ -1,5 +1,6 @@
from .vec_env import VecEnvObservationWrapper
class VecExtractDictObs(VecEnvObservationWrapper):
def __init__(self, venv, key):
self.key = key
@@ -7,4 +8,4 @@ class VecExtractDictObs(VecEnvObservationWrapper):
observation_space=venv.observation_space.spaces[self.key])
def process(self, obs):
return obs[self.key]
return obs[self.key]

View File

@@ -16,14 +16,4 @@ class TimeLimit(gym.Wrapper):
def reset(self, **kwargs):
self._elapsed_steps = 0
return self.env.reset(**kwargs)
class ClipActionsWrapper(gym.Wrapper):
def step(self, action):
import numpy as np
action = np.nan_to_num(action)
action = np.clip(action, self.action_space.low, self.action_space.high)
return self.env.step(action)
def reset(self, **kwargs):
return self.env.reset(**kwargs)
return self.env.reset(**kwargs)

View File

@@ -1,4 +1,5 @@
import os
import os.path as osp
import time
from collections import deque
import pickle
@@ -8,9 +9,9 @@ from baselines.ddpg.models import Actor, Critic
from baselines.ddpg.memory import Memory
from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
from baselines.common import set_global_seeds
import baselines.common.tf_util as U
from baselines import logger
import tensorflow as tf
import numpy as np
try:
@@ -42,6 +43,7 @@ def learn(network, env,
tau=0.01,
eval_env=None,
param_noise_adaption_interval=50,
load_path=None,
**network_kwargs):
set_global_seeds(seed)
@@ -61,8 +63,8 @@ def learn(network, env,
assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.
memory = Memory(limit=int(1e6), action_shape=env.action_space.shape, observation_shape=env.observation_space.shape)
critic = Critic(network=network, **network_kwargs)
actor = Actor(nb_actions, network=network, **network_kwargs)
critic = Critic(nb_actions, ob_shape=env.observation_space.shape, network=network, **network_kwargs)
actor = Actor(nb_actions, ob_shape=env.observation_space.shape, network=network, **network_kwargs)
action_noise = None
param_noise = None
@@ -94,12 +96,18 @@ def learn(network, env,
logger.info('Using agent with the following configuration:')
logger.info(str(agent.__dict__.items()))
if load_path is not None:
load_path = osp.expanduser(load_path)
ckpt = tf.train.Checkpoint(model=agent)
manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=None)
ckpt.restore(manager.latest_checkpoint)
print("Restoring from {}".format(manager.latest_checkpoint))
eval_episode_rewards_history = deque(maxlen=100)
episode_rewards_history = deque(maxlen=100)
sess = U.get_session()
# Prepare everything.
agent.initialize(sess)
sess.graph.finalize()
agent.initialize()
agent.reset()
@@ -133,7 +141,8 @@ def learn(network, env,
agent.reset()
for t_rollout in range(nb_rollout_steps):
# Predict next action.
action, q, _, _ = agent.step(obs, apply_noise=True, compute_Q=True)
action, q, _, _ = agent.step(tf.constant(obs), apply_noise=True, compute_Q=True)
action, q = action.numpy(), q.numpy()
# Execute next action.
if rank == 0 and render:
@@ -170,7 +179,6 @@ def learn(network, env,
agent.reset()
# Train.
epoch_actor_losses = []
epoch_critic_losses = []
@@ -178,7 +186,9 @@ def learn(network, env,
for t_train in range(nb_train_steps):
# Adapt param noise, if necessary.
if memory.nb_entries >= batch_size and t_train % param_noise_adaption_interval == 0:
distance = agent.adapt_param_noise()
batch = agent.memory.sample(batch_size=batch_size)
obs0 = tf.constant(batch['obs0'])
distance = agent.adapt_param_noise(obs0)
epoch_adaptive_distances.append(distance)
cl, al = agent.train()

View File

@@ -1,16 +1,15 @@
from copy import copy
from functools import reduce
import numpy as np
import tensorflow as tf
import tensorflow.contrib as tc
from baselines import logger
from baselines.common.mpi_adam import MpiAdam
import baselines.common.tf_util as U
from baselines.ddpg.models import Actor, Critic
from baselines.common.mpi_running_mean_std import RunningMeanStd
try:
from mpi4py import MPI
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
from baselines.common.mpi_util import sync_from_root
except ImportError:
MPI = None
@@ -25,6 +24,7 @@ def denormalize(x, stats):
return x
return x * stats.std + stats.mean
@tf.function
def reduce_std(x, axis=None, keepdims=False):
return tf.sqrt(reduce_var(x, axis=axis, keepdims=keepdims))
@@ -33,49 +33,21 @@ def reduce_var(x, axis=None, keepdims=False):
devs_squared = tf.square(x - m)
return tf.reduce_mean(devs_squared, axis=axis, keepdims=keepdims)
def get_target_updates(vars, target_vars, tau):
logger.info('setting up target updates ...')
soft_updates = []
init_updates = []
assert len(vars) == len(target_vars)
for var, target_var in zip(vars, target_vars):
logger.info(' {} <- {}'.format(target_var.name, var.name))
init_updates.append(tf.assign(target_var, var))
soft_updates.append(tf.assign(target_var, (1. - tau) * target_var + tau * var))
assert len(init_updates) == len(vars)
assert len(soft_updates) == len(vars)
return tf.group(*init_updates), tf.group(*soft_updates)
@tf.function
def update_perturbed_actor(actor, perturbed_actor, param_noise_stddev):
def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev):
assert len(actor.vars) == len(perturbed_actor.vars)
assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)
updates = []
for var, perturbed_var in zip(actor.vars, perturbed_actor.vars):
for var, perturbed_var in zip(actor.variables, perturbed_actor.variables):
if var in actor.perturbable_vars:
logger.info(' {} <- {} + noise'.format(perturbed_var.name, var.name))
updates.append(tf.assign(perturbed_var, var + tf.random_normal(tf.shape(var), mean=0., stddev=param_noise_stddev)))
perturbed_var.assign(var + tf.random.normal(shape=tf.shape(var), mean=0., stddev=param_noise_stddev))
else:
logger.info(' {} <- {}'.format(perturbed_var.name, var.name))
updates.append(tf.assign(perturbed_var, var))
assert len(updates) == len(actor.vars)
return tf.group(*updates)
perturbed_var.assign(var)
class DDPG(object):
class DDPG(tf.Module):
def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None,
gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf),
critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
# Inputs.
self.obs0 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs0')
self.obs1 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs1')
self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1')
self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
self.actions = tf.placeholder(tf.float32, shape=(None,) + action_shape, name='actions')
self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target')
self.param_noise_stddev = tf.placeholder(tf.float32, shape=(), name='param_noise_stddev')
# Parameters.
self.gamma = gamma
@@ -88,128 +60,103 @@ class DDPG(object):
self.action_range = action_range
self.return_range = return_range
self.observation_range = observation_range
self.observation_shape = observation_shape
self.critic = critic
self.actor = actor
self.actor_lr = actor_lr
self.critic_lr = critic_lr
self.clip_norm = clip_norm
self.enable_popart = enable_popart
self.reward_scale = reward_scale
self.batch_size = batch_size
self.stats_sample = None
self.critic_l2_reg = critic_l2_reg
self.actor_lr = tf.constant(actor_lr)
self.critic_lr = tf.constant(critic_lr)
# Observation normalization.
if self.normalize_observations:
with tf.variable_scope('obs_rms'):
with tf.name_scope('obs_rms'):
self.obs_rms = RunningMeanStd(shape=observation_shape)
else:
self.obs_rms = None
normalized_obs0 = tf.clip_by_value(normalize(self.obs0, self.obs_rms),
self.observation_range[0], self.observation_range[1])
normalized_obs1 = tf.clip_by_value(normalize(self.obs1, self.obs_rms),
self.observation_range[0], self.observation_range[1])
# Return normalization.
if self.normalize_returns:
with tf.variable_scope('ret_rms'):
with tf.name_scope('ret_rms'):
self.ret_rms = RunningMeanStd()
else:
self.ret_rms = None
# Create target networks.
target_actor = copy(actor)
target_actor.name = 'target_actor'
self.target_actor = target_actor
target_critic = copy(critic)
target_critic.name = 'target_critic'
self.target_critic = target_critic
# Create networks and core TF parts that are shared across setup parts.
self.actor_tf = actor(normalized_obs0)
self.normalized_critic_tf = critic(normalized_obs0, self.actions)
self.critic_tf = denormalize(tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
self.normalized_critic_with_actor_tf = critic(normalized_obs0, self.actor_tf, reuse=True)
self.critic_with_actor_tf = denormalize(tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
Q_obs1 = denormalize(target_critic(normalized_obs1, target_actor(normalized_obs1)), self.ret_rms)
self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1
self.target_critic = Critic(actor.nb_actions, observation_shape, name='target_critic', network=critic.network, **critic.network_kwargs)
self.target_actor = Actor(actor.nb_actions, observation_shape, name='target_actor', network=actor.network, **actor.network_kwargs)
# Set up parts.
if self.param_noise is not None:
self.setup_param_noise(normalized_obs0)
self.setup_actor_optimizer()
self.setup_critic_optimizer()
if self.normalize_returns and self.enable_popart:
self.setup_popart()
self.setup_stats()
self.setup_target_network_updates()
self.setup_param_noise()
self.initial_state = None # recurrent architectures not supported yet
if MPI is not None:
comm = MPI.COMM_WORLD
self.actor_optimizer = MpiAdamOptimizer(comm, self.actor.trainable_variables)
self.critic_optimizer = MpiAdamOptimizer(comm, self.critic.trainable_variables)
else:
self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)
self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)
def setup_target_network_updates(self):
actor_init_updates, actor_soft_updates = get_target_updates(self.actor.vars, self.target_actor.vars, self.tau)
critic_init_updates, critic_soft_updates = get_target_updates(self.critic.vars, self.target_critic.vars, self.tau)
self.target_init_updates = [actor_init_updates, critic_init_updates]
self.target_soft_updates = [actor_soft_updates, critic_soft_updates]
def setup_param_noise(self, normalized_obs0):
assert self.param_noise is not None
# Configure perturbed actor.
param_noise_actor = copy(self.actor)
param_noise_actor.name = 'param_noise_actor'
self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
logger.info('setting up param noise')
self.perturb_policy_ops = get_perturbed_actor_updates(self.actor, param_noise_actor, self.param_noise_stddev)
# Configure separate copy for stddev adoption.
adaptive_param_noise_actor = copy(self.actor)
adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))
def setup_actor_optimizer(self):
logger.info('setting up actor optimizer')
self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)
actor_shapes = [var.get_shape().as_list() for var in self.actor.trainable_vars]
actor_shapes = [var.get_shape().as_list() for var in self.actor.trainable_variables]
actor_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
logger.info(' actor shapes: {}'.format(actor_shapes))
logger.info(' actor params: {}'.format(actor_nb_params))
self.actor_grads = U.flatgrad(self.actor_loss, self.actor.trainable_vars, clip_norm=self.clip_norm)
self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
beta1=0.9, beta2=0.999, epsilon=1e-08)
def setup_critic_optimizer(self):
logger.info('setting up critic optimizer')
normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
if self.critic_l2_reg > 0.:
critic_reg_vars = [var for var in self.critic.trainable_vars if var.name.endswith('/w:0') and 'output' not in var.name]
for var in critic_reg_vars:
logger.info(' regularizing: {}'.format(var.name))
logger.info(' applying l2 regularization with {}'.format(self.critic_l2_reg))
critic_reg = tc.layers.apply_regularization(
tc.layers.l2_regularizer(self.critic_l2_reg),
weights_list=critic_reg_vars
)
self.critic_loss += critic_reg
critic_shapes = [var.get_shape().as_list() for var in self.critic.trainable_vars]
critic_shapes = [var.get_shape().as_list() for var in self.critic.trainable_variables]
critic_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
logger.info(' critic shapes: {}'.format(critic_shapes))
logger.info(' critic params: {}'.format(critic_nb_params))
self.critic_grads = U.flatgrad(self.critic_loss, self.critic.trainable_vars, clip_norm=self.clip_norm)
self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars,
beta1=0.9, beta2=0.999, epsilon=1e-08)
if self.critic_l2_reg > 0.:
critic_reg_vars = []
for layer in self.critic.network_builder.layers[1:]:
critic_reg_vars.append(layer.kernel)
for var in critic_reg_vars:
logger.info(' regularizing: {}'.format(var.name))
logger.info(' applying l2 regularization with {}'.format(self.critic_l2_reg))
logger.info('setting up critic target updates ...')
for var, target_var in zip(self.critic.variables, self.target_critic.variables):
logger.info(' {} <- {}'.format(target_var.name, var.name))
logger.info('setting up actor target updates ...')
for var, target_var in zip(self.actor.variables, self.target_actor.variables):
logger.info(' {} <- {}'.format(target_var.name, var.name))
if self.param_noise:
logger.info('setting up param noise')
for var, perturbed_var in zip(self.actor.variables, self.perturbed_actor.variables):
if var in actor.perturbable_vars:
logger.info(' {} <- {} + noise'.format(perturbed_var.name, var.name))
else:
logger.info(' {} <- {}'.format(perturbed_var.name, var.name))
for var, perturbed_var in zip(self.actor.variables, self.perturbed_adaptive_actor.variables):
if var in actor.perturbable_vars:
logger.info(' {} <- {} + noise'.format(perturbed_var.name, var.name))
else:
logger.info(' {} <- {}'.format(perturbed_var.name, var.name))
if self.normalize_returns and self.enable_popart:
self.setup_popart()
self.initial_state = None # recurrent architectures not supported yet
def setup_param_noise(self):
assert self.param_noise is not None
# Configure perturbed actor.
self.perturbed_actor = Actor(self.actor.nb_actions, self.observation_shape, name='param_noise_actor', network=self.actor.network, **self.actor.network_kwargs)
# Configure separate copy for stddev adoption.
self.perturbed_adaptive_actor = Actor(self.actor.nb_actions, self.observation_shape, name='adaptive_param_noise_actor', network=self.actor.network, **self.actor.network_kwargs)
def setup_popart(self):
# See https://arxiv.org/pdf/1602.07714.pdf for details.
self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
new_std = self.ret_rms.std
self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
new_mean = self.ret_rms.mean
self.renormalize_Q_outputs_op = []
for vs in [self.critic.output_vars, self.target_critic.output_vars]:
assert len(vs) == 2
M, b = vs
@@ -217,63 +164,26 @@ class DDPG(object):
assert 'bias' in b.name
assert M.get_shape()[-1] == 1
assert b.get_shape()[-1] == 1
self.renormalize_Q_outputs_op += [M.assign(M * self.old_std / new_std)]
self.renormalize_Q_outputs_op += [b.assign((b * self.old_std + self.old_mean - new_mean) / new_std)]
def setup_stats(self):
ops = []
names = []
if self.normalize_returns:
ops += [self.ret_rms.mean, self.ret_rms.std]
names += ['ret_rms_mean', 'ret_rms_std']
if self.normalize_observations:
ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)]
names += ['obs_rms_mean', 'obs_rms_std']
ops += [tf.reduce_mean(self.critic_tf)]
names += ['reference_Q_mean']
ops += [reduce_std(self.critic_tf)]
names += ['reference_Q_std']
ops += [tf.reduce_mean(self.critic_with_actor_tf)]
names += ['reference_actor_Q_mean']
ops += [reduce_std(self.critic_with_actor_tf)]
names += ['reference_actor_Q_std']
ops += [tf.reduce_mean(self.actor_tf)]
names += ['reference_action_mean']
ops += [reduce_std(self.actor_tf)]
names += ['reference_action_std']
if self.param_noise:
ops += [tf.reduce_mean(self.perturbed_actor_tf)]
names += ['reference_perturbed_action_mean']
ops += [reduce_std(self.perturbed_actor_tf)]
names += ['reference_perturbed_action_std']
self.stats_ops = ops
self.stats_names = names
@tf.function
def step(self, obs, apply_noise=True, compute_Q=True):
normalized_obs = tf.clip_by_value(normalize(obs, self.obs_rms), self.observation_range[0], self.observation_range[1])
actor_tf = self.actor(normalized_obs)
if self.param_noise is not None and apply_noise:
actor_tf = self.perturbed_actor_tf
action = self.perturbed_actor(normalized_obs)
else:
actor_tf = self.actor_tf
feed_dict = {self.obs0: U.adjust_shape(self.obs0, [obs])}
action = actor_tf
if compute_Q:
action, q = self.sess.run([actor_tf, self.critic_with_actor_tf], feed_dict=feed_dict)
normalized_critic_with_actor_tf = self.critic(normalized_obs, actor_tf)
q = denormalize(tf.clip_by_value(normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
else:
action = self.sess.run(actor_tf, feed_dict=feed_dict)
q = None
if self.action_noise is not None and apply_noise:
noise = self.action_noise()
assert noise.shape == action[0].shape
action += noise
action = np.clip(action, self.action_range[0], self.action_range[1])
action = tf.clip_by_value(action, self.action_range[0], self.action_range[1])
return action, q, None, None
@@ -287,79 +197,130 @@ class DDPG(object):
self.obs_rms.update(np.array([obs0[b]]))
def train(self):
# Get a batch.
batch = self.memory.sample(batch_size=self.batch_size)
obs0, obs1 = tf.constant(batch['obs0']), tf.constant(batch['obs1'])
actions, rewards, terminals1 = tf.constant(batch['actions']), tf.constant(batch['rewards']), tf.constant(batch['terminals1'], dtype=tf.float32)
normalized_obs0, target_Q = self.compute_normalized_obs0_and_target_Q(obs0, obs1, rewards, terminals1)
if self.normalize_returns and self.enable_popart:
old_mean, old_std, target_Q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_Q], feed_dict={
self.obs1: batch['obs1'],
self.rewards: batch['rewards'],
self.terminals1: batch['terminals1'].astype('float32'),
})
old_mean = self.ret_rms.mean
old_std = self.ret_rms.std
self.ret_rms.update(target_Q.flatten())
self.sess.run(self.renormalize_Q_outputs_op, feed_dict={
self.old_std : np.array([old_std]),
self.old_mean : np.array([old_mean]),
})
# renormalize Q outputs
new_mean = self.ret_rms.mean
new_std = self.ret_rms.std
for vs in [self.critic.output_vars, self.target_critic.output_vars]:
kernel, bias = vs
kernel.assign(kernel * old_std / new_std)
bias.assign((bias * old_std + old_mean - new_mean) / new_std)
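# Why this rescaling keeps the critic's denormalized output unchanged (Pop-Art): the
# last layer predicts Q = old_std * (kernel . x + bias) + old_mean before the update, and
# with kernel' = kernel * old_std / new_std and bias' = (bias * old_std + old_mean - new_mean) / new_std
# we get new_std * (kernel' . x + bias') + new_mean = old_std * (kernel . x + bias) + old_mean,
# i.e. the same Q for every input x, while the normalized targets now use the new statistics.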
# Run sanity check. Disabled by default since it slows down things considerably.
# print('running sanity check')
# target_Q_new, new_mean, new_std = self.sess.run([self.target_Q, self.ret_rms.mean, self.ret_rms.std], feed_dict={
# self.obs1: batch['obs1'],
# self.rewards: batch['rewards'],
# self.terminals1: batch['terminals1'].astype('float32'),
# })
# print(target_Q_new, target_Q, new_mean, new_std)
# assert (np.abs(target_Q - target_Q_new) < 1e-3).all()
actor_grads, actor_loss = self.get_actor_grads(normalized_obs0)
critic_grads, critic_loss = self.get_critic_grads(normalized_obs0, actions, target_Q)
if MPI is not None:
self.actor_optimizer.apply_gradients(actor_grads, self.actor_lr)
self.critic_optimizer.apply_gradients(critic_grads, self.critic_lr)
else:
target_Q = self.sess.run(self.target_Q, feed_dict={
self.obs1: batch['obs1'],
self.rewards: batch['rewards'],
self.terminals1: batch['terminals1'].astype('float32'),
})
# Get all gradients and perform a synced update.
ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, feed_dict={
self.obs0: batch['obs0'],
self.actions: batch['actions'],
self.critic_target: target_Q,
})
self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)
self.actor_optimizer.apply_gradients(zip(actor_grads, self.actor.trainable_variables))
self.critic_optimizer.apply_gradients(zip(critic_grads, self.critic.trainable_variables))
return critic_loss, actor_loss
def initialize(self, sess):
self.sess = sess
self.sess.run(tf.global_variables_initializer())
self.actor_optimizer.sync()
self.critic_optimizer.sync()
self.sess.run(self.target_init_updates)
@tf.function
def compute_normalized_obs0_and_target_Q(self, obs0, obs1, rewards, terminals1):
normalized_obs0 = tf.clip_by_value(normalize(obs0, self.obs_rms), self.observation_range[0], self.observation_range[1])
normalized_obs1 = tf.clip_by_value(normalize(obs1, self.obs_rms), self.observation_range[0], self.observation_range[1])
Q_obs1 = denormalize(self.target_critic(normalized_obs1, self.target_actor(normalized_obs1)), self.ret_rms)
target_Q = rewards + (1. - terminals1) * self.gamma * Q_obs1
return normalized_obs0, target_Q
@tf.function
def get_actor_grads(self, normalized_obs0):
with tf.GradientTape() as tape:
actor_tf = self.actor(normalized_obs0)
normalized_critic_with_actor_tf = self.critic(normalized_obs0, actor_tf)
critic_with_actor_tf = denormalize(tf.clip_by_value(normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
actor_loss = -tf.reduce_mean(critic_with_actor_tf)
actor_grads = tape.gradient(actor_loss, self.actor.trainable_variables)
if self.clip_norm:
actor_grads = [tf.clip_by_norm(grad, clip_norm=self.clip_norm) for grad in actor_grads]
if MPI is not None:
actor_grads = tf.concat([tf.reshape(g, (-1,)) for g in actor_grads], axis=0)
return actor_grads, actor_loss
@tf.function
def get_critic_grads(self, normalized_obs0, actions, target_Q):
with tf.GradientTape() as tape:
normalized_critic_tf = self.critic(normalized_obs0, actions)
normalized_critic_target_tf = tf.clip_by_value(normalize(target_Q, self.ret_rms), self.return_range[0], self.return_range[1])
critic_loss = tf.reduce_mean(tf.square(normalized_critic_tf - normalized_critic_target_tf))
# The first is input layer, which is ignored here.
if self.critic_l2_reg > 0.:
# Ignore the first input layer.
for layer in self.critic.network_builder.layers[1:]:
# The original l2_regularizer takes half of sum square.
critic_loss += (self.critic_l2_reg / 2.)* tf.reduce_sum(tf.square(layer.kernel))
critic_grads = tape.gradient(critic_loss, self.critic.trainable_variables)
if self.clip_norm:
critic_grads = [tf.clip_by_norm(grad, clip_norm=self.clip_norm) for grad in critic_grads]
if MPI is not None:
critic_grads = tf.concat([tf.reshape(g, (-1,)) for g in critic_grads], axis=0)
return critic_grads, critic_loss
def initialize(self):
if MPI is not None:
sync_from_root(self.actor.trainable_variables + self.critic.trainable_variables)
self.target_actor.set_weights(self.actor.get_weights())
self.target_critic.set_weights(self.critic.get_weights())
@tf.function
def update_target_net(self):
self.sess.run(self.target_soft_updates)
for var, target_var in zip(self.actor.variables, self.target_actor.variables):
target_var.assign((1. - self.tau) * target_var + self.tau * var)
for var, target_var in zip(self.critic.variables, self.target_critic.variables):
target_var.assign((1. - self.tau) * target_var + self.tau * var)
def get_stats(self):
if self.stats_sample is None:
# Get a sample and keep that fixed for all further computations.
# This allows us to estimate the change in value for the same set of inputs.
self.stats_sample = self.memory.sample(batch_size=self.batch_size)
values = self.sess.run(self.stats_ops, feed_dict={
self.obs0: self.stats_sample['obs0'],
self.actions: self.stats_sample['actions'],
})
obs0 = self.stats_sample['obs0']
actions = self.stats_sample['actions']
normalized_obs0 = tf.clip_by_value(normalize(obs0, self.obs_rms), self.observation_range[0], self.observation_range[1])
normalized_critic_tf = self.critic(normalized_obs0, actions)
critic_tf = denormalize(tf.clip_by_value(normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
actor_tf = self.actor(normalized_obs0)
normalized_critic_with_actor_tf = self.critic(normalized_obs0, actor_tf)
critic_with_actor_tf = denormalize(tf.clip_by_value(normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
names = self.stats_names[:]
assert len(names) == len(values)
stats = dict(zip(names, values))
if self.param_noise is not None:
stats = {**stats, **self.param_noise.get_stats()}
stats = {}
if self.normalize_returns:
stats['ret_rms_mean'] = self.ret_rms.mean
stats['ret_rms_std'] = self.ret_rms.std
if self.normalize_observations:
stats['obs_rms_mean'] = tf.reduce_mean(self.obs_rms.mean)
stats['obs_rms_std'] = tf.reduce_mean(self.obs_rms.std)
stats['reference_Q_mean'] = tf.reduce_mean(critic_tf)
stats['reference_Q_std'] = reduce_std(critic_tf)
stats['reference_actor_Q_mean'] = tf.reduce_mean(critic_with_actor_tf)
stats['reference_actor_Q_std'] = reduce_std(critic_with_actor_tf)
stats['reference_action_mean'] = tf.reduce_mean(actor_tf)
stats['reference_action_std'] = reduce_std(actor_tf)
if self.param_noise:
perturbed_actor_tf = self.perturbed_actor(normalized_obs0)
stats['reference_perturbed_action_mean'] = tf.reduce_mean(perturbed_actor_tf)
stats['reference_perturbed_action_std'] = reduce_std(perturbed_actor_tf)
stats.update(self.param_noise.get_stats())
return stats
def adapt_param_noise(self):
def adapt_param_noise(self, obs0):
try:
from mpi4py import MPI
except ImportError:
@@ -368,34 +329,28 @@ class DDPG(object):
if self.param_noise is None:
return 0.
# Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
batch = self.memory.sample(batch_size=self.batch_size)
self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
self.param_noise_stddev: self.param_noise.current_stddev,
})
distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
self.obs0: batch['obs0'],
self.param_noise_stddev: self.param_noise.current_stddev,
})
mean_distance = self.get_mean_distance(obs0).numpy()
if MPI is not None:
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
else:
mean_distance = distance
if MPI is not None:
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
else:
mean_distance = distance
mean_distance = MPI.COMM_WORLD.allreduce(mean_distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
self.param_noise.adapt(mean_distance)
return mean_distance
@tf.function
def get_mean_distance(self, obs0):
# Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
update_perturbed_actor(self.actor, self.perturbed_adaptive_actor, self.param_noise.current_stddev)
normalized_obs0 = tf.clip_by_value(normalize(obs0, self.obs_rms), self.observation_range[0], self.observation_range[1])
actor_tf = self.actor(normalized_obs0)
adaptive_actor_tf = self.perturbed_adaptive_actor(normalized_obs0)
mean_distance = tf.sqrt(tf.reduce_mean(tf.square(actor_tf - adaptive_actor_tf)))
return mean_distance
def reset(self):
# Reset internal state after an episode is complete.
if self.action_noise is not None:
self.action_noise.reset()
if self.param_noise is not None:
self.sess.run(self.perturb_policy_ops, feed_dict={
self.param_noise_stddev: self.param_noise.current_stddev,
})
update_perturbed_actor(self.actor, self.perturbed_actor, self.param_noise.current_stddev)

View File

@@ -2,50 +2,48 @@ import tensorflow as tf
from baselines.common.models import get_network_builder
class Model(object):
class Model(tf.keras.Model):
def __init__(self, name, network='mlp', **network_kwargs):
self.name = name
self.network_builder = get_network_builder(network)(**network_kwargs)
@property
def vars(self):
return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
@property
def trainable_vars(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
super(Model, self).__init__(name=name)
self.network = network
self.network_kwargs = network_kwargs
@property
def perturbable_vars(self):
return [var for var in self.trainable_vars if 'LayerNorm' not in var.name]
return [var for var in self.trainable_variables if 'layer_normalization' not in var.name]
class Actor(Model):
def __init__(self, nb_actions, name='actor', network='mlp', **network_kwargs):
def __init__(self, nb_actions, ob_shape, name='actor', network='mlp', **network_kwargs):
super().__init__(name=name, network=network, **network_kwargs)
self.nb_actions = nb_actions
self.network_builder = get_network_builder(network)(**network_kwargs)(ob_shape)
self.output_layer = tf.keras.layers.Dense(units=self.nb_actions,
activation=tf.keras.activations.tanh,
kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
_ = self.output_layer(self.network_builder.outputs[0])
def __call__(self, obs, reuse=False):
with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
x = self.network_builder(obs)
x = tf.layers.dense(x, self.nb_actions, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
x = tf.nn.tanh(x)
return x
@tf.function
def call(self, obs):
return self.output_layer(self.network_builder(obs))
class Critic(Model):
def __init__(self, name='critic', network='mlp', **network_kwargs):
def __init__(self, nb_actions, ob_shape, name='critic', network='mlp', **network_kwargs):
super().__init__(name=name, network=network, **network_kwargs)
self.layer_norm = True
self.network_builder = get_network_builder(network)(**network_kwargs)((ob_shape[0] + nb_actions,))
self.output_layer = tf.keras.layers.Dense(units=1,
kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3),
name='output')
_ = self.output_layer(self.network_builder.outputs[0])
def __call__(self, obs, action, reuse=False):
with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
x = tf.concat([obs, action], axis=-1) # this assumes observation and action can be concatenated
x = self.network_builder(x)
x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3), name='output')
return x
@tf.function
def call(self, obs, actions):
x = tf.concat([obs, actions], axis=-1) # this assumes observation and action can be concatenated
x = self.network_builder(x)
return self.output_layer(x)
@property
def output_vars(self):
output_vars = [var for var in self.trainable_vars if 'output' in var.name]
return output_vars
return self.output_layer.trainable_variables

View File

@@ -1,16 +0,0 @@
from baselines.common.tests.util import smoketest
def _run(argstr):
smoketest('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr)
def test_popart():
_run('--normalize_returns=True --popart=True')
def test_noise_normal():
_run('--noise_type=normal_0.1')
def test_noise_ou():
_run('--noise_type=ou_0.1')
def test_noise_adaptive():
_run('--noise_type=adaptive-param_0.2,normal_0.1')

View File

@@ -1,6 +1,6 @@
from baselines.deepq import models # noqa
from baselines.deepq.build_graph import build_act, build_train # noqa
from baselines.deepq.deepq import learn, load_act # noqa
from baselines.deepq.deepq_learner import DEEPQ #noqa
from baselines.deepq.deepq import learn # noqa
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa
def wrap_atari_dqn(env):

View File

@@ -1,449 +0,0 @@
"""Deep Q learning graph
The functions in this file can be used to create the following functions:
======= act ========
Function to choose an action given an observation
Parameters
----------
observation: object
Observation that can be fed into the output of make_obs_ph
stochastic: bool
if set to False all the actions are always deterministic (default False)
update_eps_ph: float
update epsilon to a new value; if negative, no update happens
(default: no update)
Returns
-------
Tensor of dtype tf.int64 and shape (BATCH_SIZE,) with an action to be performed for
every element of the batch.
======= act (in case of parameter noise) ========
Function to choose an action given an observation
Parameters
----------
observation: object
Observation that can be fed into the output of make_obs_ph
stochastic: bool
if set to False all the actions are always deterministic (default False)
update_eps_ph: float
update epsilon to a new value, if negative no update happens
(default: no update)
reset_ph: bool
reset the perturbed policy by sampling a new perturbation
update_param_noise_threshold_ph: float
the desired threshold for the difference between non-perturbed and perturbed policy
update_param_noise_scale_ph: bool
whether or not to update the scale of the noise for the next time it is re-perturbed
Returns
-------
Tensor of dtype tf.int64 and shape (BATCH_SIZE,) with an action to be performed for
every element of the batch.
======= train =======
Function that takes a transition (s,a,r,s') and optimizes Bellman equation's error:
td_error = Q(s,a) - (r + gamma * max_a' Q(s', a'))
loss = huber_loss[td_error]
Parameters
----------
obs_t: object
a batch of observations
action: np.array
actions that were selected upon seeing obs_t.
dtype must be int32 and shape must be (batch_size,)
reward: np.array
immediate reward attained after executing those actions
dtype must be float32 and shape must be (batch_size,)
obs_tp1: object
observations that followed obs_t
done: np.array
1 if obs_t was the last observation in the episode and 0 otherwise
obs_tp1 gets ignored, but must be of the valid shape.
dtype must be float32 and shape must be (batch_size,)
weight: np.array
importance weights for every element of the batch (gradient is multiplied
by the importance weight) dtype must be float32 and shape must be (batch_size,)
Returns
-------
td_error: np.array
a list of differences between Q(s,a) and the target in Bellman's equation.
dtype is float32 and shape is (batch_size,)
======= update_target ========
copy the parameters from optimized Q function to the target Q function.
In Q learning we actually optimize the following error:
Q(s,a) - (r + gamma * max_a' Q'(s', a'))
Where Q' is lagging behind Q to stabilize the learning. For example, for Atari
Q' is set to Q once every 10000 training steps.
"""
import tensorflow as tf
import baselines.common.tf_util as U
def scope_vars(scope, trainable_only=False):
"""
Get variables inside a scope
The scope can be specified as a string
Parameters
----------
scope: str or VariableScope
scope in which the variables reside.
trainable_only: bool
whether or not to return only the variables that were marked as trainable.
Returns
-------
vars: [tf.Variable]
list of variables in `scope`.
"""
return tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.GLOBAL_VARIABLES,
scope=scope if isinstance(scope, str) else scope.name
)
def scope_name():
"""Returns the name of current scope as a string, e.g. deepq/q_func"""
return tf.get_variable_scope().name
def absolute_scope_name(relative_scope_name):
"""Appends parent scope name to `relative_scope_name`"""
return scope_name() + "/" + relative_scope_name
def default_param_noise_filter(var):
if var not in tf.trainable_variables():
# We never perturb non-trainable vars.
return False
if "fully_connected" in var.name:
# We perturb fully-connected layers.
return True
# The remaining layers are likely conv or layer norm layers, which we do not wish to
# perturb (in the former case because they only extract features, in the latter case because
# we use them for normalization purposes). If you change your network, you will likely want
# to re-consider which layers to perturb and which to keep untouched.
return False
def build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None):
"""Creates the act function:
Parameters
----------
make_obs_ph: str -> tf.placeholder or TfInput
a function that takes a name and creates an input placeholder with that name
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of every action.
num_actions: int
number of actions.
scope: str or VariableScope
optional scope for variable_scope.
reuse: bool or None
whether or not the variables should be reused. To be able to reuse the scope must be given.
Returns
-------
act: (tf.Variable, bool, float) -> tf.Variable
function to select an action given an observation.
See the top of the file for details.
"""
with tf.variable_scope(scope, reuse=reuse):
observations_ph = make_obs_ph("observation")
stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")
eps = tf.get_variable("eps", (), initializer=tf.constant_initializer(0))
q_values = q_func(observations_ph.get(), num_actions, scope="q_func")
deterministic_actions = tf.argmax(q_values, axis=1)
batch_size = tf.shape(observations_ph.get())[0]
random_actions = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64)
chose_random = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
stochastic_actions = tf.where(chose_random, random_actions, deterministic_actions)
output_actions = tf.cond(stochastic_ph, lambda: stochastic_actions, lambda: deterministic_actions)
update_eps_expr = eps.assign(tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
_act = U.function(inputs=[observations_ph, stochastic_ph, update_eps_ph],
outputs=output_actions,
givens={update_eps_ph: -1.0, stochastic_ph: True},
updates=[update_eps_expr])
def act(ob, stochastic=True, update_eps=-1):
return _act(ob, stochastic, update_eps)
return act
def build_act_with_param_noise(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None, param_noise_filter_func=None):
"""Creates the act function with support for parameter space noise exploration (https://arxiv.org/abs/1706.01905):
Parameters
----------
make_obs_ph: str -> tf.placeholder or TfInput
a function that take a name and creates a placeholder of input with that name
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of every action.
num_actions: int
number of actions.
scope: str or VariableScope
optional scope for variable_scope.
reuse: bool or None
whether or not the variables should be reused. To be able to reuse the scope must be given.
param_noise_filter_func: tf.Variable -> bool
function that decides whether or not a variable should be perturbed. Only applicable
if param_noise is True. If set to None, default_param_noise_filter is used by default.
Returns
-------
act: (tf.Variable, bool, float, bool, float, bool) -> tf.Variable
function to select an action given an observation.
See the top of the file for details.
"""
if param_noise_filter_func is None:
param_noise_filter_func = default_param_noise_filter
with tf.variable_scope(scope, reuse=reuse):
observations_ph = make_obs_ph("observation")
stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")
update_param_noise_threshold_ph = tf.placeholder(tf.float32, (), name="update_param_noise_threshold")
update_param_noise_scale_ph = tf.placeholder(tf.bool, (), name="update_param_noise_scale")
reset_ph = tf.placeholder(tf.bool, (), name="reset")
eps = tf.get_variable("eps", (), initializer=tf.constant_initializer(0))
param_noise_scale = tf.get_variable("param_noise_scale", (), initializer=tf.constant_initializer(0.01), trainable=False)
param_noise_threshold = tf.get_variable("param_noise_threshold", (), initializer=tf.constant_initializer(0.05), trainable=False)
# Unmodified Q.
q_values = q_func(observations_ph.get(), num_actions, scope="q_func")
# Perturbable Q used for the actual rollout.
q_values_perturbed = q_func(observations_ph.get(), num_actions, scope="perturbed_q_func")
# We have to wrap this code into a function due to the way tf.cond() works. See
# https://stackoverflow.com/questions/37063952/confused-by-the-behavior-of-tf-cond for
# a more detailed discussion.
def perturb_vars(original_scope, perturbed_scope):
all_vars = scope_vars(absolute_scope_name(original_scope))
all_perturbed_vars = scope_vars(absolute_scope_name(perturbed_scope))
assert len(all_vars) == len(all_perturbed_vars)
perturb_ops = []
for var, perturbed_var in zip(all_vars, all_perturbed_vars):
if param_noise_filter_func(perturbed_var):
# Perturb this variable.
op = tf.assign(perturbed_var, var + tf.random_normal(shape=tf.shape(var), mean=0., stddev=param_noise_scale))
else:
# Do not perturb, just assign.
op = tf.assign(perturbed_var, var)
perturb_ops.append(op)
assert len(perturb_ops) == len(all_vars)
return tf.group(*perturb_ops)
# Set up functionality to re-compute `param_noise_scale`. This perturbs yet another copy
# of the network and measures the effect of that perturbation in action space. If the perturbation
# is too big, reduce scale of perturbation, otherwise increase.
q_values_adaptive = q_func(observations_ph.get(), num_actions, scope="adaptive_q_func")
perturb_for_adaption = perturb_vars(original_scope="q_func", perturbed_scope="adaptive_q_func")
kl = tf.reduce_sum(tf.nn.softmax(q_values) * (tf.log(tf.nn.softmax(q_values)) - tf.log(tf.nn.softmax(q_values_adaptive))), axis=-1)
mean_kl = tf.reduce_mean(kl)
def update_scale():
with tf.control_dependencies([perturb_for_adaption]):
update_scale_expr = tf.cond(mean_kl < param_noise_threshold,
lambda: param_noise_scale.assign(param_noise_scale * 1.01),
lambda: param_noise_scale.assign(param_noise_scale / 1.01),
)
return update_scale_expr
# Functionality to update the threshold for parameter space noise.
update_param_noise_threshold_expr = param_noise_threshold.assign(tf.cond(update_param_noise_threshold_ph >= 0,
lambda: update_param_noise_threshold_ph, lambda: param_noise_threshold))
# Put everything together.
deterministic_actions = tf.argmax(q_values_perturbed, axis=1)
batch_size = tf.shape(observations_ph.get())[0]
random_actions = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64)
chose_random = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
stochastic_actions = tf.where(chose_random, random_actions, deterministic_actions)
output_actions = tf.cond(stochastic_ph, lambda: stochastic_actions, lambda: deterministic_actions)
update_eps_expr = eps.assign(tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
updates = [
update_eps_expr,
tf.cond(reset_ph, lambda: perturb_vars(original_scope="q_func", perturbed_scope="perturbed_q_func"), lambda: tf.group(*[])),
tf.cond(update_param_noise_scale_ph, lambda: update_scale(), lambda: tf.Variable(0., trainable=False)),
update_param_noise_threshold_expr,
]
_act = U.function(inputs=[observations_ph, stochastic_ph, update_eps_ph, reset_ph, update_param_noise_threshold_ph, update_param_noise_scale_ph],
outputs=output_actions,
givens={update_eps_ph: -1.0, stochastic_ph: True, reset_ph: False, update_param_noise_threshold_ph: False, update_param_noise_scale_ph: False},
updates=updates)
def act(ob, reset=False, update_param_noise_threshold=False, update_param_noise_scale=False, stochastic=True, update_eps=-1):
return _act(ob, stochastic, update_eps, reset, update_param_noise_threshold, update_param_noise_scale)
return act
def build_train(make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=None, gamma=1.0,
double_q=True, scope="deepq", reuse=None, param_noise=False, param_noise_filter_func=None):
"""Creates the train function:
Parameters
----------
make_obs_ph: str -> tf.placeholder or TfInput
a function that takes a name and creates a placeholder of input with that name
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of every action.
num_actions: int
number of actions
reuse: bool
whether or not to reuse the graph variables
optimizer: tf.train.Optimizer
optimizer to use for the Q-learning objective.
grad_norm_clipping: float or None
clip gradient norms to this value. If None no clipping is performed.
gamma: float
discount rate.
double_q: bool
if true will use Double Q Learning (https://arxiv.org/abs/1509.06461).
In general it is a good idea to keep it enabled.
scope: str or VariableScope
optional scope for variable_scope.
reuse: bool or None
whether or not the variables should be reused. To be able to reuse the scope must be given.
param_noise: bool
whether or not to use parameter space noise (https://arxiv.org/abs/1706.01905)
param_noise_filter_func: tf.Variable -> bool
function that decides whether or not a variable should be perturbed. Only applicable
if param_noise is True. If set to None, default_param_noise_filter is used by default.
Returns
-------
act: (tf.Variable, bool, float) -> tf.Variable
function to select an action given an observation.
See the top of the file for details.
train: (object, np.array, np.array, object, np.array, np.array) -> np.array
optimize the error in Bellman's equation.
See the top of the file for details.
update_target: () -> ()
copy the parameters from optimized Q function to the target Q function.
See the top of the file for details.
debug: {str: function}
a bunch of functions to print debug data like q_values.
"""
if param_noise:
act_f = build_act_with_param_noise(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse,
param_noise_filter_func=param_noise_filter_func)
else:
act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse)
with tf.variable_scope(scope, reuse=reuse):
# set up placeholders
obs_t_input = make_obs_ph("obs_t")
act_t_ph = tf.placeholder(tf.int32, [None], name="action")
rew_t_ph = tf.placeholder(tf.float32, [None], name="reward")
obs_tp1_input = make_obs_ph("obs_tp1")
done_mask_ph = tf.placeholder(tf.float32, [None], name="done")
importance_weights_ph = tf.placeholder(tf.float32, [None], name="weight")
# q network evaluation
q_t = q_func(obs_t_input.get(), num_actions, scope="q_func", reuse=True) # reuse parameters from act
q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name + "/q_func")
# target q network evaluation
q_tp1 = q_func(obs_tp1_input.get(), num_actions, scope="target_q_func")
target_q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name + "/target_q_func")
# q scores for actions which we know were selected in the given state.
q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t_ph, num_actions), 1)
# compute estimate of best possible value starting from state at t + 1
if double_q:
q_tp1_using_online_net = q_func(obs_tp1_input.get(), num_actions, scope="q_func", reuse=True)
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
q_tp1_best = tf.reduce_sum(q_tp1 * tf.one_hot(q_tp1_best_using_online_net, num_actions), 1)
else:
q_tp1_best = tf.reduce_max(q_tp1, 1)
q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best
# compute RHS of bellman equation
q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked
# compute the error (potentially clipped)
td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
errors = U.huber_loss(td_error)
weighted_error = tf.reduce_mean(importance_weights_ph * errors)
# compute optimization op (potentially with gradient clipping)
if grad_norm_clipping is not None:
gradients = optimizer.compute_gradients(weighted_error, var_list=q_func_vars)
for i, (grad, var) in enumerate(gradients):
if grad is not None:
gradients[i] = (tf.clip_by_norm(grad, grad_norm_clipping), var)
optimize_expr = optimizer.apply_gradients(gradients)
else:
optimize_expr = optimizer.minimize(weighted_error, var_list=q_func_vars)
# update_target_fn will be called periodically to copy Q network to target Q network
update_target_expr = []
for var, var_target in zip(sorted(q_func_vars, key=lambda v: v.name),
sorted(target_q_func_vars, key=lambda v: v.name)):
update_target_expr.append(var_target.assign(var))
update_target_expr = tf.group(*update_target_expr)
# Create callable functions
train = U.function(
inputs=[
obs_t_input,
act_t_ph,
rew_t_ph,
obs_tp1_input,
done_mask_ph,
importance_weights_ph
],
outputs=td_error,
updates=[optimize_expr]
)
update_target = U.function([], [], updates=[update_target_expr])
q_values = U.function([obs_t_input], q_t)
return act_f, train, update_target, {'q_values': q_values}

View File

@@ -1,96 +1,19 @@
import os
import tempfile
import os.path as osp
import tensorflow as tf
import zipfile
import cloudpickle
import numpy as np
import baselines.common.tf_util as U
from baselines.common.tf_util import load_variables, save_variables
from baselines import logger
from baselines.common.schedules import LinearSchedule
from baselines.common.vec_env.vec_env import VecEnv
from baselines.common import set_global_seeds
from baselines import deepq
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from baselines.deepq.utils import ObservationInput
from baselines.common.tf_util import get_session
from baselines.deepq.models import build_q_func
class ActWrapper(object):
def __init__(self, act, act_params):
self._act = act
self._act_params = act_params
self.initial_state = None
@staticmethod
def load_act(path):
with open(path, "rb") as f:
model_data, act_params = cloudpickle.load(f)
act = deepq.build_act(**act_params)
sess = tf.Session()
sess.__enter__()
with tempfile.TemporaryDirectory() as td:
arc_path = os.path.join(td, "packed.zip")
with open(arc_path, "wb") as f:
f.write(model_data)
zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
load_variables(os.path.join(td, "model"))
return ActWrapper(act, act_params)
def __call__(self, *args, **kwargs):
return self._act(*args, **kwargs)
def step(self, observation, **kwargs):
# DQN doesn't use RNNs so we ignore states and masks
kwargs.pop('S', None)
kwargs.pop('M', None)
return self._act([observation], **kwargs), None, None, None
def save_act(self, path=None):
"""Save model to a pickle located at `path`"""
if path is None:
path = os.path.join(logger.get_dir(), "model.pkl")
with tempfile.TemporaryDirectory() as td:
save_variables(os.path.join(td, "model"))
arc_name = os.path.join(td, "packed.zip")
with zipfile.ZipFile(arc_name, 'w') as zipf:
for root, dirs, files in os.walk(td):
for fname in files:
file_path = os.path.join(root, fname)
if file_path != arc_name:
zipf.write(file_path, os.path.relpath(file_path, td))
with open(arc_name, "rb") as f:
model_data = f.read()
with open(path, "wb") as f:
cloudpickle.dump((model_data, self._act_params), f)
def save(self, path):
save_variables(path)
def load_act(path):
"""Load act function that was returned by learn function.
Parameters
----------
path: str
path to the act function pickle
Returns
-------
act: ActWrapper
function that takes a batch of observations
and returns actions.
"""
return ActWrapper.load_act(path)
def learn(env,
network,
@@ -187,7 +110,6 @@ def learn(env,
"""
# Create all the functions necessary to train the model
sess = get_session()
set_global_seeds(seed)
q_func = build_q_func(network, **network_kwargs)
@@ -196,26 +118,23 @@ def learn(env,
# by cloudpickle when serializing make_obs_ph
observation_space = env.observation_space
def make_obs_ph(name):
return ObservationInput(observation_space, name=name)
act, train, update_target, debug = deepq.build_train(
make_obs_ph=make_obs_ph,
model = deepq.DEEPQ(
q_func=q_func,
observation_shape=env.observation_space.shape,
num_actions=env.action_space.n,
optimizer=tf.train.AdamOptimizer(learning_rate=lr),
gamma=gamma,
lr=lr,
grad_norm_clipping=10,
gamma=gamma,
param_noise=param_noise
)
act_params = {
'make_obs_ph': make_obs_ph,
'q_func': q_func,
'num_actions': env.action_space.n,
}
act = ActWrapper(act, act_params)
if load_path is not None:
load_path = osp.expanduser(load_path)
ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=None)
ckpt.restore(manager.latest_checkpoint)
print("Restoring from {}".format(manager.latest_checkpoint))
# Create the replay buffer
if prioritized_replay:
@@ -233,101 +152,83 @@ def learn(env,
initial_p=1.0,
final_p=exploration_final_eps)
# Initialize the parameters and copy them to the target network.
U.initialize()
update_target()
model.update_target()
episode_rewards = [0.0]
saved_mean_reward = None
obs = env.reset()
# always mimic the vectorized env
if not isinstance(env, VecEnv):
obs = np.expand_dims(np.array(obs), axis=0)
reset = True
with tempfile.TemporaryDirectory() as td:
td = checkpoint_path or td
for t in range(total_timesteps):
if callback is not None:
if callback(locals(), globals()):
break
kwargs = {}
if not param_noise:
update_eps = tf.constant(exploration.value(t))
update_param_noise_threshold = 0.
else:
update_eps = tf.constant(0.)
# Compute the threshold such that the KL divergence between perturbed and non-perturbed
# policy is comparable to eps-greedy exploration with eps = exploration.value(t).
# See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
# for detailed explanation.
update_param_noise_threshold = -np.log(1. - exploration.value(t) + exploration.value(t) / float(env.action_space.n))
kwargs['reset'] = reset
kwargs['update_param_noise_threshold'] = update_param_noise_threshold
kwargs['update_param_noise_scale'] = True
action, _, _, _ = model.step(tf.constant(obs), update_eps=update_eps, **kwargs)
action = action[0].numpy()
reset = False
new_obs, rew, done, _ = env.step(action)
# Store transition in the replay buffer.
if not isinstance(env, VecEnv):
new_obs = np.expand_dims(np.array(new_obs), axis=0)
replay_buffer.add(obs[0], action, rew, new_obs[0], float(done))
else:
replay_buffer.add(obs[0], action, rew[0], new_obs[0], float(done[0]))
# # Store transition in the replay buffer.
# replay_buffer.add(obs, action, rew, new_obs, float(done))
obs = new_obs
model_file = os.path.join(td, "model")
model_saved = False
episode_rewards[-1] += rew
if done:
obs = env.reset()
if not isinstance(env, VecEnv):
obs = np.expand_dims(np.array(obs), axis=0)
episode_rewards.append(0.0)
reset = True
if tf.train.latest_checkpoint(td) is not None:
load_variables(model_file)
logger.log('Loaded model from {}'.format(model_file))
model_saved = True
elif load_path is not None:
load_variables(load_path)
logger.log('Loaded model from {}'.format(load_path))
for t in range(total_timesteps):
if callback is not None:
if callback(locals(), globals()):
break
# Take action and update exploration to the newest value
kwargs = {}
if not param_noise:
update_eps = exploration.value(t)
update_param_noise_threshold = 0.
if t > learning_starts and t % train_freq == 0:
# Minimize the error in Bellman's equation on a batch sampled from replay buffer.
if prioritized_replay:
experience = replay_buffer.sample(batch_size, beta=beta_schedule.value(t))
(obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience
else:
update_eps = 0.
# Compute the threshold such that the KL divergence between perturbed and non-perturbed
# policy is comparable to eps-greedy exploration with eps = exploration.value(t).
# See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
# for detailed explanation.
update_param_noise_threshold = -np.log(1. - exploration.value(t) + exploration.value(t) / float(env.action_space.n))
kwargs['reset'] = reset
kwargs['update_param_noise_threshold'] = update_param_noise_threshold
kwargs['update_param_noise_scale'] = True
action = act(np.array(obs)[None], update_eps=update_eps, **kwargs)[0]
env_action = action
reset = False
new_obs, rew, done, _ = env.step(env_action)
# Store transition in the replay buffer.
replay_buffer.add(obs, action, rew, new_obs, float(done))
obs = new_obs
obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size)
weights, batch_idxes = np.ones_like(rewards), None
obses_t, obses_tp1 = tf.constant(obses_t), tf.constant(obses_tp1)
actions, rewards, dones = tf.constant(actions), tf.constant(rewards), tf.constant(dones)
weights = tf.constant(weights)
td_errors = model.train(obses_t, actions, rewards, obses_tp1, dones, weights)
if prioritized_replay:
new_priorities = np.abs(td_errors) + prioritized_replay_eps
replay_buffer.update_priorities(batch_idxes, new_priorities)
episode_rewards[-1] += rew
if done:
obs = env.reset()
episode_rewards.append(0.0)
reset = True
if t > learning_starts and t % target_network_update_freq == 0:
# Update target network periodically.
model.update_target()
if t > learning_starts and t % train_freq == 0:
# Minimize the error in Bellman's equation on a batch sampled from replay buffer.
if prioritized_replay:
experience = replay_buffer.sample(batch_size, beta=beta_schedule.value(t))
(obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience
else:
obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size)
weights, batch_idxes = np.ones_like(rewards), None
td_errors = train(obses_t, actions, rewards, obses_tp1, dones, weights)
if prioritized_replay:
new_priorities = np.abs(td_errors) + prioritized_replay_eps
replay_buffer.update_priorities(batch_idxes, new_priorities)
mean_100ep_reward = round(np.mean(episode_rewards[-101:-1]), 1)
num_episodes = len(episode_rewards)
if done and print_freq is not None and len(episode_rewards) % print_freq == 0:
logger.record_tabular("steps", t)
logger.record_tabular("episodes", num_episodes)
logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
logger.record_tabular("% time spent exploring", int(100 * exploration.value(t)))
logger.dump_tabular()
if t > learning_starts and t % target_network_update_freq == 0:
# Update target network periodically.
update_target()
mean_100ep_reward = round(np.mean(episode_rewards[-101:-1]), 1)
num_episodes = len(episode_rewards)
if done and print_freq is not None and len(episode_rewards) % print_freq == 0:
logger.record_tabular("steps", t)
logger.record_tabular("episodes", num_episodes)
logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
logger.record_tabular("% time spent exploring", int(100 * exploration.value(t)))
logger.dump_tabular()
if (checkpoint_freq is not None and t > learning_starts and
num_episodes > 100 and t % checkpoint_freq == 0):
if saved_mean_reward is None or mean_100ep_reward > saved_mean_reward:
if print_freq is not None:
logger.log("Saving model due to mean reward increase: {} -> {}".format(
saved_mean_reward, mean_100ep_reward))
save_variables(model_file)
model_saved = True
saved_mean_reward = mean_100ep_reward
if model_saved:
if print_freq is not None:
logger.log("Restored model with mean reward: {}".format(saved_mean_reward))
load_variables(model_file)
return act
return model
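
A hedged sketch of driving this TF2 learn() (not part of the commit): the keyword arguments mirror the ones referenced in the function body and in the examples elsewhere in the repository, and the exact public signature may differ.

```python
import gym
from baselines import deepq

env = gym.make("CartPole-v0")
# returns the trained DEEPQ model rather than a pickled act function
model = deepq.learn(
    env,
    network='mlp',
    lr=1e-3,
    total_timesteps=100000,
    buffer_size=50000,
    exploration_fraction=0.1,
    exploration_final_eps=0.02,
    print_freq=10,
)
```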

View File

@@ -0,0 +1,191 @@
"""Deep Q model
The functions in this model:
======= step ========
Function to choose an action given an observation
Parameters
----------
observation: tensor
Observation that can be fed into the output of make_obs_ph
stochastic: bool
if set to False all the actions are always deterministic (default False)
update_eps: float
update epsilon to a new value; if negative, no update happens
(default: no update)
Returns
-------
Tensor of dtype tf.int64 and shape (BATCH_SIZE,) with an action to be performed for
every element of the batch.
(NOT IMPLEMENTED YET)
======= step (in case of parameter noise) ========
Function to choose an action given an observation
Parameters
----------
observation: object
Observation that can be fed into the output of make_obs_ph
stochastic: bool
if set to False all the actions are always deterministic (default False)
update_eps: float
update epsilon to a new value, if negative no update happens
(default: no update)
reset: bool
reset the perturbed policy by sampling a new perturbation
update_param_noise_threshold: float
the desired threshold for the difference between non-perturbed and perturbed policy
update_param_noise_scale: bool
whether or not to update the scale of the noise for the next time it is re-perturbed
Returns
-------
Tensor of dtype tf.int64 and shape (BATCH_SIZE,) with an action to be performed for
every element of the batch.
======= train =======
Function that takes a transition (s,a,r,s',d) and optimizes Bellman equation's error:
td_error = Q(s,a) - (r + gamma * (1-d) * max_a' Q(s', a'))
loss = huber_loss[td_error]
Parameters
----------
obs_t: object
a batch of observations
action: np.array
actions that were selected upon seeing obs_t.
dtype must be int32 and shape must be (batch_size,)
reward: np.array
immediate reward attained after executing those actions
dtype must be float32 and shape must be (batch_size,)
obs_tp1: object
observations that followed obs_t
done: np.array
1 if obs_t was the last observation in the episode and 0 otherwise
obs_tp1 gets ignored, but must be of the valid shape.
dtype must be float32 and shape must be (batch_size,)
weight: np.array
importance weights for every element of the batch (gradient is multiplied
by the importance weight) dtype must be float32 and shape must be (batch_size,)
Returns
-------
td_error: np.array
a list of differences between Q(s,a) and the target in Bellman's equation.
dtype is float32 and shape is (batch_size,)
======= update_target ========
copy the parameters from optimized Q function to the target Q function.
In Q learning we actually optimize the following error:
Q(s,a) - (r + gamma * max_a' Q'(s', a'))
Where Q' lags behind Q to stabilize the learning. For example, for Atari
Q' is set to Q once every 10000 training steps.
"""
import tensorflow as tf
@tf.function
def huber_loss(x, delta=1.0):
"""Reference: https://en.wikipedia.org/wiki/Huber_loss"""
return tf.where(
tf.abs(x) < delta,
tf.square(x) * 0.5,
delta * (tf.abs(x) - 0.5 * delta)
)
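# For example, with the default delta=1.0:
#   huber_loss(tf.constant([0.5, 2.0]))  ->  [0.125, 1.5]
# i.e. quadratic for |x| < delta and linear beyond it, so large TD errors
# do not dominate the gradient.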
class DEEPQ(tf.Module):
def __init__(self, q_func, observation_shape, num_actions, lr, grad_norm_clipping=None, gamma=1.0,
double_q=True, param_noise=False, param_noise_filter_func=None):
self.num_actions = num_actions
self.gamma = gamma
self.double_q = double_q
self.param_noise = param_noise
self.param_noise_filter_func = param_noise_filter_func
self.grad_norm_clipping = grad_norm_clipping
self.optimizer = tf.keras.optimizers.Adam(lr)
with tf.name_scope('q_network'):
self.q_network = q_func(observation_shape, num_actions)
with tf.name_scope('target_q_network'):
self.target_q_network = q_func(observation_shape, num_actions)
self.eps = tf.Variable(0., name="eps")
@tf.function
def step(self, obs, stochastic=True, update_eps=-1):
if self.param_noise:
raise ValueError('not supporting noise yet')
else:
q_values = self.q_network(obs)
deterministic_actions = tf.argmax(q_values, axis=1)
batch_size = tf.shape(obs)[0]
random_actions = tf.random.uniform(tf.stack([batch_size]), minval=0, maxval=self.num_actions, dtype=tf.int64)
chose_random = tf.random.uniform(tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < self.eps
stochastic_actions = tf.where(chose_random, random_actions, deterministic_actions)
if stochastic:
output_actions = stochastic_actions
else:
output_actions = deterministic_actions
if update_eps >= 0:
self.eps.assign(update_eps)
return output_actions, None, None, None
@tf.function()
def train(self, obs0, actions, rewards, obs1, dones, importance_weights):
with tf.GradientTape() as tape:
q_t = self.q_network(obs0)
q_t_selected = tf.reduce_sum(q_t * tf.one_hot(actions, self.num_actions, dtype=tf.float32), 1)
q_tp1 = self.target_q_network(obs1)
if self.double_q:
q_tp1_using_online_net = self.q_network(obs1)
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
q_tp1_best = tf.reduce_sum(q_tp1 * tf.one_hot(q_tp1_best_using_online_net, self.num_actions, dtype=tf.float32), 1)
else:
q_tp1_best = tf.reduce_max(q_tp1, 1)
dones = tf.cast(dones, q_tp1_best.dtype)
q_tp1_best_masked = (1.0 - dones) * q_tp1_best
q_t_selected_target = rewards + self.gamma * q_tp1_best_masked
td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
errors = huber_loss(td_error)
weighted_error = tf.reduce_mean(importance_weights * errors)
grads = tape.gradient(weighted_error, self.q_network.trainable_variables)
if self.grad_norm_clipping:
# clip each gradient to the configured norm and apply the clipped gradients
clipped_grads = []
for grad in grads:
clipped_grads.append(tf.clip_by_norm(grad, self.grad_norm_clipping))
grads = clipped_grads
grads_and_vars = zip(grads, self.q_network.trainable_variables)
self.optimizer.apply_gradients(grads_and_vars)
return td_error
@tf.function(autograph=False)
def update_target(self):
q_vars = self.q_network.trainable_variables
target_q_vars = self.target_q_network.trainable_variables
for var, var_target in zip(q_vars, target_q_vars):
var_target.assign(var)
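
A minimal usage sketch of the DEEPQ class defined above (not part of this commit): the tiny_q_func builder and the CartPole-like shapes are illustrative assumptions, chosen only to show the step/train/update_target cycle described in the docstring.

```python
import numpy as np
import tensorflow as tf

def tiny_q_func(observation_shape, num_actions):
    # hypothetical builder: a small Functional model mapping observations to per-action Q-values
    inputs = tf.keras.Input(shape=observation_shape)
    hidden = tf.keras.layers.Dense(64, activation='tanh')(inputs)
    q_values = tf.keras.layers.Dense(num_actions)(hidden)
    return tf.keras.Model(inputs=inputs, outputs=q_values)

model = DEEPQ(q_func=tiny_q_func, observation_shape=(4,), num_actions=2,
              lr=1e-3, grad_norm_clipping=10, gamma=0.99)
model.update_target()  # sync the target network before training starts

obs = tf.constant(np.random.randn(32, 4).astype(np.float32))
actions, _, _, _ = model.step(obs, update_eps=tf.constant(0.1))  # eps-greedy actions, shape (32,)
next_obs = tf.constant(np.random.randn(32, 4).astype(np.float32))
rewards = tf.constant(np.ones(32, dtype=np.float32))
dones = tf.constant(np.zeros(32, dtype=np.float32))
weights = tf.constant(np.ones(32, dtype=np.float32))
td_errors = model.train(obs, actions, rewards, next_obs, dones, weights)  # one Bellman-error update
```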

View File

@@ -1,79 +0,0 @@
import gym
import itertools
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
import baselines.common.tf_util as U
from baselines import logger
from baselines import deepq
from baselines.deepq.replay_buffer import ReplayBuffer
from baselines.deepq.utils import ObservationInput
from baselines.common.schedules import LinearSchedule
def model(inpt, num_actions, scope, reuse=False):
"""This model takes as input an observation and returns values of all actions."""
with tf.variable_scope(scope, reuse=reuse):
out = inpt
out = layers.fully_connected(out, num_outputs=64, activation_fn=tf.nn.tanh)
out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)
return out
if __name__ == '__main__':
with U.make_session(num_cpu=8):
# Create the environment
env = gym.make("CartPole-v0")
# Create all the functions necessary to train the model
act, train, update_target, debug = deepq.build_train(
make_obs_ph=lambda name: ObservationInput(env.observation_space, name=name),
q_func=model,
num_actions=env.action_space.n,
optimizer=tf.train.AdamOptimizer(learning_rate=5e-4),
)
# Create the replay buffer
replay_buffer = ReplayBuffer(50000)
# Create the schedule for exploration starting from 1 (every action is random) down to
# 0.02 (98% of actions are selected according to values predicted by the model).
exploration = LinearSchedule(schedule_timesteps=10000, initial_p=1.0, final_p=0.02)
# Initialize the parameters and copy them to the target network.
U.initialize()
update_target()
episode_rewards = [0.0]
obs = env.reset()
for t in itertools.count():
# Take action and update exploration to the newest value
action = act(obs[None], update_eps=exploration.value(t))[0]
new_obs, rew, done, _ = env.step(action)
# Store transition in the replay buffer.
replay_buffer.add(obs, action, rew, new_obs, float(done))
obs = new_obs
episode_rewards[-1] += rew
if done:
obs = env.reset()
episode_rewards.append(0)
is_solved = t > 100 and np.mean(episode_rewards[-101:-1]) >= 200
if is_solved:
# Show off the result
env.render()
else:
# Minimize the error in Bellman's equation on a batch sampled from replay buffer.
if t > 1000:
obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(32)
train(obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards))
# Update target network periodically.
if t % 1000 == 0:
update_target()
if done and len(episode_rewards) % 10 == 0:
logger.record_tabular("steps", t)
logger.record_tabular("episodes", len(episode_rewards))
logger.record_tabular("mean episode reward", round(np.mean(episode_rewards[-101:-1]), 1))
logger.record_tabular("% time spent exploring", int(100 * exploration.value(t)))
logger.dump_tabular()

View File

@@ -1,21 +0,0 @@
import gym
from baselines import deepq
def main():
env = gym.make("CartPole-v0")
act = deepq.learn(env, network='mlp', total_timesteps=0, load_path="cartpole_model.pkl")
while True:
obs, done = env.reset(), False
episode_rew = 0
while not done:
env.render()
obs, rew, done, _ = env.step(act(obs[None])[0])
episode_rew += rew
print("Episode reward", episode_rew)
if __name__ == '__main__':
main()

View File

@@ -1,27 +0,0 @@
import gym
from baselines import deepq
from baselines.common import models
def main():
env = gym.make("MountainCar-v0")
act = deepq.learn(
env,
network=models.mlp(num_layers=1, num_hidden=64),
total_timesteps=0,
load_path='mountaincar_model.pkl'
)
while True:
obs, done = env.reset(), False
episode_rew = 0
while not done:
env.render()
obs, rew, done, _ = env.step(act(obs[None])[0])
episode_rew += rew
print("Episode reward", episode_rew)
if __name__ == '__main__':
main()

View File

@@ -1,28 +0,0 @@
import gym
from baselines import deepq
def main():
env = gym.make("PongNoFrameskip-v4")
env = deepq.wrap_atari_dqn(env)
model = deepq.learn(
env,
"conv_only",
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
hiddens=[256],
dueling=True,
total_timesteps=0
)
while True:
obs, done = env.reset(), False
episode_rew = 0
while not done:
env.render()
obs, rew, done, _ = env.step(model(obs[None])[0])
episode_rew += rew
print("Episode reward", episode_rew)
if __name__ == '__main__':
main()

View File

@@ -1,30 +0,0 @@
import gym
from baselines import deepq
def callback(lcl, _glb):
# stop training if reward exceeds 199
is_solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199
return is_solved
def main():
env = gym.make("CartPole-v0")
act = deepq.learn(
env,
network='mlp',
lr=1e-3,
total_timesteps=100000,
buffer_size=50000,
exploration_fraction=0.1,
exploration_final_eps=0.02,
print_freq=10,
callback=callback
)
print("Saving model to cartpole_model.pkl")
act.save("cartpole_model.pkl")
if __name__ == '__main__':
main()

View File

@@ -1,26 +0,0 @@
import gym
from baselines import deepq
from baselines.common import models
def main():
env = gym.make("MountainCar-v0")
# Enabling layer_norm here is important for parameter space noise!
act = deepq.learn(
env,
network=models.mlp(num_hidden=64, num_layers=1),
lr=1e-3,
total_timesteps=100000,
buffer_size=50000,
exploration_fraction=0.1,
exploration_final_eps=0.1,
print_freq=10,
param_noise=True
)
print("Saving model to mountaincar_model.pkl")
act.save("mountaincar_model.pkl")
if __name__ == '__main__':
main()

View File

@@ -1,34 +0,0 @@
from baselines import deepq
from baselines import bench
from baselines import logger
from baselines.common.atari_wrappers import make_atari
def main():
logger.configure()
env = make_atari('PongNoFrameskip-v4')
env = bench.Monitor(env, logger.get_dir())
env = deepq.wrap_atari_dqn(env)
model = deepq.learn(
env,
"conv_only",
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
hiddens=[256],
dueling=True,
lr=1e-4,
total_timesteps=int(1e7),
buffer_size=10000,
exploration_fraction=0.1,
exploration_final_eps=0.01,
train_freq=4,
learning_starts=10000,
target_network_update_freq=1000,
gamma=0.99,
)
model.save('pong_model.pkl')
env.close()
if __name__ == '__main__':
main()

View File

@@ -1,5 +1,4 @@
import tensorflow as tf
import tensorflow.contrib.layers as layers
def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **network_kwargs):
@@ -7,39 +6,42 @@ def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **netwo
from baselines.common.models import get_network_builder
network = get_network_builder(network)(**network_kwargs)
def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
with tf.variable_scope(scope, reuse=reuse):
latent = network(input_placeholder)
if isinstance(latent, tuple):
if latent[1] is not None:
raise NotImplementedError("DQN is not compatible with recurrent policies yet")
latent = latent[0]
def q_func_builder(input_shape, num_actions):
# the sub Functional model which does not include the top layer.
model = network(input_shape)
latent = layers.flatten(latent)
# wrapping the sub Functional model with layers that compute action scores into another Functional model.
latent = model.outputs
if len(latent) > 1:
if latent[1] is not None:
raise NotImplementedError("DQN is not compatible with recurrent policies yet")
latent = latent[0]
with tf.variable_scope("action_value"):
action_out = latent
latent = tf.keras.layers.Flatten()(latent)
with tf.name_scope("action_value"):
action_out = latent
for hidden in hiddens:
action_out = tf.keras.layers.Dense(units=hidden, activation=None)(action_out)
if layer_norm:
action_out = tf.keras.layers.LayerNormalization(center=True, scale=True)(action_out)
action_out = tf.nn.relu(action_out)
action_scores = tf.keras.layers.Dense(units=num_actions, activation=None)(action_out)
if dueling:
with tf.name_scope("state_value"):
state_out = latent
for hidden in hiddens:
action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
state_out = tf.keras.layers.Dense(units=hidden, activation=None)(state_out)
if layer_norm:
action_out = layers.layer_norm(action_out, center=True, scale=True)
action_out = tf.nn.relu(action_out)
action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)
if dueling:
with tf.variable_scope("state_value"):
state_out = latent
for hidden in hiddens:
state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None)
if layer_norm:
state_out = layers.layer_norm(state_out, center=True, scale=True)
state_out = tf.nn.relu(state_out)
state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
action_scores_mean = tf.reduce_mean(action_scores, 1)
action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
q_out = state_score + action_scores_centered
else:
q_out = action_scores
return q_out
state_out = tf.keras.layers.LayerNormalization(center=True, scale=True)(state_out)
state_out = tf.nn.relu(state_out)
state_score = tf.keras.layers.Dense(units=1, activation=None)(state_out)
action_scores_mean = tf.reduce_mean(action_scores, 1)
action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
q_out = state_score + action_scores_centered
else:
q_out = action_scores
return tf.keras.Model(inputs=model.inputs, outputs=[q_out])
return q_func_builder
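
The dueling branch above combines a scalar state value with mean-centered action advantages, Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)); a tiny numeric sketch of that aggregation with toy values:

```python
import tensorflow as tf

action_scores = tf.constant([[1.0, 3.0, 5.0]])         # A(s, a) for 3 actions
state_score = tf.constant([[10.0]])                     # V(s)
action_scores_mean = tf.reduce_mean(action_scores, 1)   # [3.0]
action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
q_out = state_score + action_scores_centered            # [[8.0, 10.0, 12.0]]
```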

View File

@@ -32,6 +32,9 @@ class ReplayBuffer(object):
def _encode_sample(self, idxes):
obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
data = self._storage[0]
ob_dtype = data[0].dtype
ac_dtype = data[1].dtype
for i in idxes:
data = self._storage[i]
obs_t, action, reward, obs_tp1, done = data
@@ -40,7 +43,7 @@ class ReplayBuffer(object):
rewards.append(reward)
obses_tp1.append(np.array(obs_tp1, copy=False))
dones.append(done)
return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones)
return np.array(obses_t, dtype=ob_dtype), np.array(actions, dtype=ac_dtype), np.array(rewards, dtype=np.float32), np.array(obses_tp1, dtype=ob_dtype), np.array(dones, dtype=np.float32)
def sample(self, batch_size):
"""Sample a batch of experiences.
@@ -162,7 +165,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
p_sample = self._it_sum[idx] / self._it_sum.sum()
weight = (p_sample * len(self._storage)) ** (-beta)
weights.append(weight / max_weight)
weights = np.array(weights)
weights = np.array(weights, dtype=np.float32)
encoded_sample = self._encode_sample(idxes)
return tuple(list(encoded_sample) + [weights, idxes])
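
A short sketch of the sampling path this dtype change supports (not from the commit): with stable dtypes the sampled batches can be wrapped directly in tf.constant for model.train; the toy observation and action values are illustrative.

```python
import numpy as np
from baselines.deepq.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(1000)
obs = np.zeros(4, dtype=np.float32)
buffer.add(obs, np.int64(1), 0.5, obs, 0.0)  # (obs_t, action, reward, obs_tp1, done)
obses_t, actions, rewards, obses_tp1, dones = buffer.sample(1)
assert obses_t.dtype == np.float32 and actions.dtype == np.int64
assert rewards.dtype == np.float32 and dones.dtype == np.float32
```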

View File

@@ -1,59 +0,0 @@
from baselines.common.input import observation_input
from baselines.common.tf_util import adjust_shape
# ================================================================
# Placeholders
# ================================================================
class TfInput(object):
def __init__(self, name="(unnamed)"):
"""Generalized Tensorflow placeholder. The main differences are:
- possibly uses multiple placeholders internally and returns multiple values
- can apply light postprocessing to the value feed to placeholder.
"""
self.name = name
def get(self):
"""Return the tf variable(s) representing the possibly postprocessed value
of placeholder(s).
"""
raise NotImplementedError
def make_feed_dict(self, data):
"""Given data input it to the placeholder(s)."""
raise NotImplementedError
class PlaceholderTfInput(TfInput):
def __init__(self, placeholder):
"""Wrapper for regular tensorflow placeholder."""
super().__init__(placeholder.name)
self._placeholder = placeholder
def get(self):
return self._placeholder
def make_feed_dict(self, data):
return {self._placeholder: adjust_shape(self._placeholder, data)}
class ObservationInput(PlaceholderTfInput):
def __init__(self, observation_space, name=None):
"""Creates an input placeholder tailored to a specific observation space
Parameters
----------
observation_space:
observation space of the environment. Should be one of the gym.spaces types
name: str
tensorflow name of the underlying placeholder
"""
inpt, self.processed_inpt = observation_input(observation_space, name=name)
super().__init__(inpt)
def get(self):
return self.processed_inpt

View File

@@ -1,52 +0,0 @@
# Generative Adversarial Imitation Learning (GAIL)
- Original paper: https://arxiv.org/abs/1606.03476
For results benchmarking on MuJoCo, please navigate to [here](result/gail-result.md)
## If you want to train an imitation learning agent
### Step 1: Download expert data
Download the expert data into `./data`, [download link](https://drive.google.com/drive/folders/1h3H4AY_ZBx08hz-Ct0Nxxus-V1melu1U?usp=sharing)
### Step 2: Run GAIL
Run with single rank:
```bash
python -m baselines.gail.run_mujoco
```
Run with multiple ranks:
```bash
mpirun -np 16 python -m baselines.gail.run_mujoco
```
See help (`-h`) for more options.
#### In case you want to run Behavior Cloning (BC)
```bash
python -m baselines.gail.behavior_clone
```
See help (`-h`) for more options.
## Contributing
Bug reports and pull requests are welcome on GitHub at https://github.com/openai/baselines/pulls.
## Maintainers
- Yuan-Hong Liao, andrewliao11_at_gmail_dot_com
- Ryan Julian, ryanjulian_at_gmail_dot_com
## Others
Thanks to the open source:
- @openai/imitation
- @carpedm20/deep-rl-tensorflow

View File

@@ -1,87 +0,0 @@
'''
Reference: https://github.com/openai/imitation
I follow the architecture from the official repository
'''
import tensorflow as tf
import numpy as np
from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.common import tf_util as U
def logsigmoid(a):
'''Equivalent to tf.log(tf.sigmoid(a))'''
return -tf.nn.softplus(-a)
""" Reference: https://github.com/openai/imitation/blob/99fbccf3e060b6e6c739bdf209758620fcdefd3c/policyopt/thutil.py#L48-L51"""
def logit_bernoulli_entropy(logits):
ent = (1.-tf.nn.sigmoid(logits))*logits - logsigmoid(logits)
return ent
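# Derivation sketch: with p = sigmoid(x),
#   H = -p*log(p) - (1-p)*log(1-p)
# and log(sigmoid(x)) = -softplus(-x), log(1-sigmoid(x)) = -(x + softplus(-x)),
# so H = softplus(-x) + (1-p)*x = (1 - sigmoid(x))*x - logsigmoid(x), as computed above.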
class TransitionClassifier(object):
def __init__(self, env, hidden_size, entcoeff=0.001, lr_rate=1e-3, scope="adversary"):
self.scope = scope
self.observation_shape = env.observation_space.shape
self.actions_shape = env.action_space.shape
self.input_shape = tuple([o+a for o, a in zip(self.observation_shape, self.actions_shape)])
self.num_actions = env.action_space.shape[0]
self.hidden_size = hidden_size
self.build_ph()
# Build graph
generator_logits = self.build_graph(self.generator_obs_ph, self.generator_acs_ph, reuse=False)
expert_logits = self.build_graph(self.expert_obs_ph, self.expert_acs_ph, reuse=True)
# Build accuracy
generator_acc = tf.reduce_mean(tf.to_float(tf.nn.sigmoid(generator_logits) < 0.5))
expert_acc = tf.reduce_mean(tf.to_float(tf.nn.sigmoid(expert_logits) > 0.5))
# Build regression loss
# let x = logits, z = targets.
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
generator_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=generator_logits, labels=tf.zeros_like(generator_logits))
generator_loss = tf.reduce_mean(generator_loss)
expert_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=expert_logits, labels=tf.ones_like(expert_logits))
expert_loss = tf.reduce_mean(expert_loss)
# Build entropy loss
logits = tf.concat([generator_logits, expert_logits], 0)
entropy = tf.reduce_mean(logit_bernoulli_entropy(logits))
entropy_loss = -entcoeff*entropy
# Loss + Accuracy terms
self.losses = [generator_loss, expert_loss, entropy, entropy_loss, generator_acc, expert_acc]
self.loss_name = ["generator_loss", "expert_loss", "entropy", "entropy_loss", "generator_acc", "expert_acc"]
self.total_loss = generator_loss + expert_loss + entropy_loss
# Build Reward for policy
self.reward_op = -tf.log(1-tf.nn.sigmoid(generator_logits)+1e-8)
var_list = self.get_trainable_variables()
self.lossandgrad = U.function([self.generator_obs_ph, self.generator_acs_ph, self.expert_obs_ph, self.expert_acs_ph],
self.losses + [U.flatgrad(self.total_loss, var_list)])
def build_ph(self):
self.generator_obs_ph = tf.placeholder(tf.float32, (None, ) + self.observation_shape, name="observations_ph")
self.generator_acs_ph = tf.placeholder(tf.float32, (None, ) + self.actions_shape, name="actions_ph")
self.expert_obs_ph = tf.placeholder(tf.float32, (None, ) + self.observation_shape, name="expert_observations_ph")
self.expert_acs_ph = tf.placeholder(tf.float32, (None, ) + self.actions_shape, name="expert_actions_ph")
def build_graph(self, obs_ph, acs_ph, reuse=False):
with tf.variable_scope(self.scope):
if reuse:
tf.get_variable_scope().reuse_variables()
with tf.variable_scope("obfilter"):
self.obs_rms = RunningMeanStd(shape=self.observation_shape)
obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std
_input = tf.concat([obs, acs_ph], axis=1) # concatenate the two inputs to form a transition
p_h1 = tf.contrib.layers.fully_connected(_input, self.hidden_size, activation_fn=tf.nn.tanh)
p_h2 = tf.contrib.layers.fully_connected(p_h1, self.hidden_size, activation_fn=tf.nn.tanh)
logits = tf.contrib.layers.fully_connected(p_h2, 1, activation_fn=tf.identity)
return logits
def get_trainable_variables(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
def get_reward(self, obs, acs):
sess = tf.get_default_session()
if len(obs.shape) == 1:
obs = np.expand_dims(obs, 0)
if len(acs.shape) == 1:
acs = np.expand_dims(acs, 0)
feed_dict = {self.generator_obs_ph: obs, self.generator_acs_ph: acs}
reward = sess.run(self.reward_op, feed_dict)
return reward

View File

@@ -1,124 +0,0 @@
'''
The code is used to train a BC imitator, or to pretrain a GAIL imitator
'''
import argparse
import tempfile
import os.path as osp
import gym
import logging
from tqdm import tqdm
import tensorflow as tf
from baselines.gail import mlp_policy
from baselines import bench
from baselines import logger
from baselines.common import set_global_seeds, tf_util as U
from baselines.common.misc_util import boolean_flag
from baselines.common.mpi_adam import MpiAdam
from baselines.gail.run_mujoco import runner
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
def argsparser():
parser = argparse.ArgumentParser("Tensorflow Implementation of Behavior Cloning")
parser.add_argument('--env_id', help='environment ID', default='Hopper-v1')
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--expert_path', type=str, default='data/deterministic.trpo.Hopper.0.00.npz')
parser.add_argument('--checkpoint_dir', help='the directory to save model', default='checkpoint')
parser.add_argument('--log_dir', help='the directory to save log file', default='log')
# Mujoco Dataset Configuration
parser.add_argument('--traj_limitation', type=int, default=-1)
# Network Configuration (Using MLP Policy)
parser.add_argument('--policy_hidden_size', type=int, default=100)
# for evaluation
boolean_flag(parser, 'stochastic_policy', default=False, help='use stochastic/deterministic policy to evaluate')
boolean_flag(parser, 'save_sample', default=False, help='save the trajectories or not')
parser.add_argument('--BC_max_iter', help='Max iteration for training BC', type=int, default=1e5)
return parser.parse_args()
def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4,
adam_epsilon=1e-5, optim_stepsize=3e-4,
ckpt_dir=None, log_dir=None, task_name=None,
verbose=False):
val_per_iter = int(max_iters/10)
ob_space = env.observation_space
ac_space = env.action_space
pi = policy_func("pi", ob_space, ac_space) # Construct network for new policy
# placeholder
ob = U.get_placeholder_cached(name="ob")
ac = pi.pdtype.sample_placeholder([None])
stochastic = U.get_placeholder_cached(name="stochastic")
loss = tf.reduce_mean(tf.square(ac-pi.ac))
var_list = pi.get_trainable_variables()
adam = MpiAdam(var_list, epsilon=adam_epsilon)
lossandgrad = U.function([ob, ac, stochastic], [loss]+[U.flatgrad(loss, var_list)])
U.initialize()
adam.sync()
logger.log("Pretraining with Behavior Cloning...")
for iter_so_far in tqdm(range(int(max_iters))):
ob_expert, ac_expert = dataset.get_next_batch(optim_batch_size, 'train')
train_loss, g = lossandgrad(ob_expert, ac_expert, True)
adam.update(g, optim_stepsize)
if verbose and iter_so_far % val_per_iter == 0:
ob_expert, ac_expert = dataset.get_next_batch(-1, 'val')
val_loss, _ = lossandgrad(ob_expert, ac_expert, True)
logger.log("Training loss: {}, Validation loss: {}".format(train_loss, val_loss))
if ckpt_dir is None:
savedir_fname = tempfile.TemporaryDirectory().name
else:
savedir_fname = osp.join(ckpt_dir, task_name)
U.save_state(savedir_fname, var_list=pi.get_variables())
return savedir_fname
def get_task_name(args):
task_name = 'BC'
task_name += '.{}'.format(args.env_id.split("-")[0])
task_name += '.traj_limitation_{}'.format(args.traj_limitation)
task_name += ".seed_{}".format(args.seed)
return task_name
def main(args):
U.make_session(num_cpu=1).__enter__()
set_global_seeds(args.seed)
env = gym.make(args.env_id)
def policy_fn(name, ob_space, ac_space, reuse=False):
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
reuse=reuse, hid_size=args.policy_hidden_size, num_hid_layers=2)
env = bench.Monitor(env, logger.get_dir() and
osp.join(logger.get_dir(), "monitor.json"))
env.seed(args.seed)
gym.logger.setLevel(logging.WARN)
task_name = get_task_name(args)
args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
args.log_dir = osp.join(args.log_dir, task_name)
dataset = Mujoco_Dset(expert_path=args.expert_path, traj_limitation=args.traj_limitation)
savedir_fname = learn(env,
policy_fn,
dataset,
max_iters=args.BC_max_iter,
ckpt_dir=args.checkpoint_dir,
log_dir=args.log_dir,
task_name=task_name,
verbose=True)
avg_len, avg_ret = runner(env,
policy_fn,
savedir_fname,
timesteps_per_batch=1024,
number_trajs=10,
stochastic_policy=args.stochastic_policy,
save=args.save_sample,
reuse=True)
if __name__ == '__main__':
args = argsparser()
main(args)

View File

@@ -1,114 +0,0 @@
'''
Data structure of the input .npz:
the data is saved in python dictionary format with keys: 'acs', 'ep_rets', 'rews', 'obs'
the value of each item is a list storing the expert trajectory sequentially
a transition can be: (data['obs'][t], data['acs'][t], data['obs'][t+1]) with reward data['rews'][t]
'''
from baselines import logger
import numpy as np
class Dset(object):
def __init__(self, inputs, labels, randomize):
self.inputs = inputs
self.labels = labels
assert len(self.inputs) == len(self.labels)
self.randomize = randomize
self.num_pairs = len(inputs)
self.init_pointer()
def init_pointer(self):
self.pointer = 0
if self.randomize:
idx = np.arange(self.num_pairs)
np.random.shuffle(idx)
self.inputs = self.inputs[idx, :]
self.labels = self.labels[idx, :]
def get_next_batch(self, batch_size):
# if batch_size is negative -> return all
if batch_size < 0:
return self.inputs, self.labels
if self.pointer + batch_size >= self.num_pairs:
self.init_pointer()
end = self.pointer + batch_size
inputs = self.inputs[self.pointer:end, :]
labels = self.labels[self.pointer:end, :]
self.pointer = end
return inputs, labels
class Mujoco_Dset(object):
def __init__(self, expert_path, train_fraction=0.7, traj_limitation=-1, randomize=True):
traj_data = np.load(expert_path)
if traj_limitation < 0:
traj_limitation = len(traj_data['obs'])
obs = traj_data['obs'][:traj_limitation]
acs = traj_data['acs'][:traj_limitation]
# obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length
# and S is the environment observation/action space.
# Flatten to (N * L, prod(S))
if len(obs.shape) > 2:
self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])])
self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])])
else:
self.obs = np.vstack(obs)
self.acs = np.vstack(acs)
self.rets = traj_data['ep_rets'][:traj_limitation]
self.avg_ret = sum(self.rets)/len(self.rets)
self.std_ret = np.std(np.array(self.rets))
if len(self.acs) > 2:
self.acs = np.squeeze(self.acs)
assert len(self.obs) == len(self.acs)
self.num_traj = min(traj_limitation, len(traj_data['obs']))
self.num_transition = len(self.obs)
self.randomize = randomize
self.dset = Dset(self.obs, self.acs, self.randomize)
# for behavior cloning
self.train_set = Dset(self.obs[:int(self.num_transition*train_fraction), :],
self.acs[:int(self.num_transition*train_fraction), :],
self.randomize)
self.val_set = Dset(self.obs[int(self.num_transition*train_fraction):, :],
self.acs[int(self.num_transition*train_fraction):, :],
self.randomize)
self.log_info()
def log_info(self):
logger.log("Total trajectorues: %d" % self.num_traj)
logger.log("Total transitions: %d" % self.num_transition)
logger.log("Average returns: %f" % self.avg_ret)
logger.log("Std for returns: %f" % self.std_ret)
def get_next_batch(self, batch_size, split=None):
if split is None:
return self.dset.get_next_batch(batch_size)
elif split == 'train':
return self.train_set.get_next_batch(batch_size)
elif split == 'val':
return self.val_set.get_next_batch(batch_size)
else:
raise NotImplementedError
def plot(self):
import matplotlib.pyplot as plt
plt.hist(self.rets)
plt.savefig("histogram_rets.png")
plt.close()
def test(expert_path, traj_limitation, plot):
dset = Mujoco_Dset(expert_path, traj_limitation=traj_limitation)
if plot:
dset.plot()
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--expert_path", type=str, default="../data/deterministic.trpo.Hopper.0.00.npz")
parser.add_argument("--traj_limitation", type=int, default=None)
parser.add_argument("--plot", type=bool, default=False)
args = parser.parse_args()
test(args.expert_path, args.traj_limitation, args.plot)

View File

@@ -1,147 +0,0 @@
'''
This code is used to evaluate the imitators trained with different numbers of trajectories
and plot the results in the same figure for easy comparison.
'''
import argparse
import os
import glob
import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from baselines.gail import run_mujoco
from baselines.gail import mlp_policy
from baselines.common import set_global_seeds, tf_util as U
from baselines.common.misc_util import boolean_flag
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
plt.style.use('ggplot')
CONFIG = {
'traj_limitation': [1, 5, 10, 50],
}
def load_dataset(expert_path):
dataset = Mujoco_Dset(expert_path=expert_path)
return dataset
def argsparser():
parser = argparse.ArgumentParser('Do evaluation')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--policy_hidden_size', type=int, default=100)
parser.add_argument('--env', type=str, choices=['Hopper', 'Walker2d', 'HalfCheetah',
'Humanoid', 'HumanoidStandup'])
boolean_flag(parser, 'stochastic_policy', default=False, help='use stochastic/deterministic policy to evaluate')
return parser.parse_args()
def evaluate_env(env_name, seed, policy_hidden_size, stochastic, reuse, prefix):
def get_checkpoint_dir(checkpoint_list, limit, prefix):
for checkpoint in checkpoint_list:
if ('limitation_'+str(limit) in checkpoint) and (prefix in checkpoint):
return checkpoint
return None
def policy_fn(name, ob_space, ac_space, reuse=False):
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
reuse=reuse, hid_size=policy_hidden_size, num_hid_layers=2)
data_path = os.path.join('data', 'deterministic.trpo.' + env_name + '.0.00.npz')
dataset = load_dataset(data_path)
checkpoint_list = glob.glob(os.path.join('checkpoint', '*' + env_name + ".*"))
log = {
'traj_limitation': [],
'upper_bound': [],
'avg_ret': [],
'avg_len': [],
'normalized_ret': []
}
for i, limit in enumerate(CONFIG['traj_limitation']):
# Do one evaluation
upper_bound = sum(dataset.rets[:limit])/limit
checkpoint_dir = get_checkpoint_dir(checkpoint_list, limit, prefix=prefix)
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
env = gym.make(env_name + '-v1')
env.seed(seed)
print('Trajectory limitation: {}, Load checkpoint: {}, '.format(limit, checkpoint_path))
avg_len, avg_ret = run_mujoco.runner(env,
policy_fn,
checkpoint_path,
timesteps_per_batch=1024,
number_trajs=10,
stochastic_policy=stochastic,
reuse=((i != 0) or reuse))
normalized_ret = avg_ret/upper_bound
print('Upper bound: {}, evaluation returns: {}, normalized scores: {}'.format(
upper_bound, avg_ret, normalized_ret))
log['traj_limitation'].append(limit)
log['upper_bound'].append(upper_bound)
log['avg_ret'].append(avg_ret)
log['avg_len'].append(avg_len)
log['normalized_ret'].append(normalized_ret)
env.close()
return log
def plot(env_name, bc_log, gail_log, stochastic):
upper_bound = bc_log['upper_bound']
bc_avg_ret = bc_log['avg_ret']
gail_avg_ret = gail_log['avg_ret']
plt.plot(CONFIG['traj_limitation'], upper_bound)
plt.plot(CONFIG['traj_limitation'], bc_avg_ret)
plt.plot(CONFIG['traj_limitation'], gail_avg_ret)
plt.xlabel('Number of expert trajectories')
plt.ylabel('Accumulated reward')
plt.title('{} unnormalized scores'.format(env_name))
plt.legend(['expert', 'bc-imitator', 'gail-imitator'], loc='lower right')
plt.grid(b=True, which='major', color='gray', linestyle='--')
if stochastic:
title_name = 'result/{}-unnormalized-stochastic-scores.png'.format(env_name)
else:
title_name = 'result/{}-unnormalized-deterministic-scores.png'.format(env_name)
plt.savefig(title_name)
plt.close()
bc_normalized_ret = bc_log['normalized_ret']
gail_normalized_ret = gail_log['normalized_ret']
plt.plot(CONFIG['traj_limitation'], np.ones(len(CONFIG['traj_limitation'])))
plt.plot(CONFIG['traj_limitation'], bc_normalized_ret)
plt.plot(CONFIG['traj_limitation'], gail_normalized_ret)
plt.xlabel('Number of expert trajectories')
plt.ylabel('Normalized performance')
plt.title('{} normalized scores'.format(env_name))
plt.legend(['expert', 'bc-imitator', 'gail-imitator'], loc='lower right')
plt.grid(b=True, which='major', color='gray', linestyle='--')
if stochastic:
title_name = 'result/{}-normalized-stochastic-scores.png'.format(env_name)
else:
title_name = 'result/{}-normalized-deterministic-scores.png'.format(env_name)
plt.ylim(0, 1.6)
plt.savefig(title_name)
plt.close()
def main(args):
U.make_session(num_cpu=1).__enter__()
set_global_seeds(args.seed)
print('Evaluating {}'.format(args.env))
bc_log = evaluate_env(args.env, args.seed, args.policy_hidden_size,
args.stochastic_policy, False, 'BC')
print('Evaluation for {}'.format(args.env))
print(bc_log)
gail_log = evaluate_env(args.env, args.seed, args.policy_hidden_size,
args.stochastic_policy, True, 'gail')
print('Evaluation for {}'.format(args.env))
print(gail_log)
plot(args.env, bc_log, gail_log, args.stochastic_policy)
if __name__ == '__main__':
args = argsparser()
main(args)

View File

@@ -1,75 +0,0 @@
'''
from baselines/ppo1/mlp_policy.py and add simple modification
(1) add reuse argument
(2) cache the `stochastic` placeholder
'''
import tensorflow as tf
import gym
import baselines.common.tf_util as U
from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.common.distributions import make_pdtype
from baselines.acktr.utils import dense
class MlpPolicy(object):
recurrent = False
def __init__(self, name, reuse=False, *args, **kwargs):
with tf.variable_scope(name):
if reuse:
tf.get_variable_scope().reuse_variables()
self._init(*args, **kwargs)
self.scope = tf.get_variable_scope().name
def _init(self, ob_space, ac_space, hid_size, num_hid_layers, gaussian_fixed_var=True):
assert isinstance(ob_space, gym.spaces.Box)
self.pdtype = pdtype = make_pdtype(ac_space)
sequence_length = None
ob = U.get_placeholder(name="ob", dtype=tf.float32, shape=[sequence_length] + list(ob_space.shape))
with tf.variable_scope("obfilter"):
self.ob_rms = RunningMeanStd(shape=ob_space.shape)
obz = tf.clip_by_value((ob - self.ob_rms.mean) / self.ob_rms.std, -5.0, 5.0)
last_out = obz
for i in range(num_hid_layers):
last_out = tf.nn.tanh(dense(last_out, hid_size, "vffc%i" % (i+1), weight_init=U.normc_initializer(1.0)))
self.vpred = dense(last_out, 1, "vffinal", weight_init=U.normc_initializer(1.0))[:, 0]
last_out = obz
for i in range(num_hid_layers):
last_out = tf.nn.tanh(dense(last_out, hid_size, "polfc%i" % (i+1), weight_init=U.normc_initializer(1.0)))
if gaussian_fixed_var and isinstance(ac_space, gym.spaces.Box):
mean = dense(last_out, pdtype.param_shape()[0]//2, "polfinal", U.normc_initializer(0.01))
logstd = tf.get_variable(name="logstd", shape=[1, pdtype.param_shape()[0]//2], initializer=tf.zeros_initializer())
pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
else:
pdparam = dense(last_out, pdtype.param_shape()[0], "polfinal", U.normc_initializer(0.01))
self.pd = pdtype.pdfromflat(pdparam)
self.state_in = []
self.state_out = []
# change for BC
stochastic = U.get_placeholder(name="stochastic", dtype=tf.bool, shape=())
ac = U.switch(stochastic, self.pd.sample(), self.pd.mode())
self.ac = ac
self._act = U.function([stochastic, ob], [ac, self.vpred])
def act(self, stochastic, ob):
ac1, vpred1 = self._act(stochastic, ob[None])
return ac1[0], vpred1[0]
def get_variables(self):
return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
def get_trainable_variables(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
def get_initial_state(self):
return []

14 binary image files deleted (not shown)

Some files were not shown because too many files have changed in this diff.