improve lstm code.

gyunt
2019-03-21 23:34:40 +09:00
parent 7d3cba70a9
commit f8d22815cd
2 changed files with 39 additions and 52 deletions


@@ -1 +1 @@
-from baselines.ppo2.layers import *
+from baselines.ppo2.layers import ppo_lstm, ppo_cnn_lstm


@@ -7,25 +7,6 @@ from baselines.common.models import register
 @register("ppo_lstm")
 def ppo_lstm(nlstm=128, layer_norm=False):
-    """
-    Builds LSTM (Long-Short Term Memory) network to be used in a policy.
-    Note that the resulting function returns not only the output of the LSTM
-    (i.e. hidden state of lstm for each step in the sequence), but also a dictionary
-    with auxiliary tensors to be set as policy attributes.
-
-    Specifically,
-        S is a placeholder to feed current state (LSTM state has to be managed outside policy)
-        M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
-        initial_state is a numpy array containing initial lstm state (usually zeros)
-        state is the output LSTM state (to be fed into S at the next call)
-
-    nlstm: int. LSTM hidden state size
-    layer_norm: bool. if True, layer-normalized version of LSTM is used
-
-    Returns:
-        function that builds LSTM with a given input tensor / placeholder
-    """
     def network_fn(input, mask):
         memory_size = nlstm * 2
         nbatch = input.shape[0]
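
Editor's note, not part of the commit: the docstring removed above described the recurrent-state contract -- S and M are placeholders for the incoming LSTM state and the done-mask, initial_state is the all-zeros starting state, and the returned state must be fed back into S on the next call. A minimal NumPy-only sketch of that bookkeeping on the rollout side is given below; the dummy step() function, shapes, and episode-end sampling are illustrative assumptions only.

    import numpy as np

    nenv, nlstm, nsteps = 4, 128, 5

    def step(obs, S, M):
        # Stand-in for a policy step: a real policy would feed S into the state
        # placeholder and M into the mask placeholder of the built network.
        next_state = 0.9 * S + 0.1 * obs.mean()
        next_state = next_state * (1.0 - M)[:, None]  # zero the state where an episode ended
        actions = np.zeros(nenv, dtype=np.int64)
        return actions, next_state

    state = np.zeros((nenv, 2 * nlstm), dtype=np.float32)  # initial_state: usually zeros
    dones = np.zeros(nenv, dtype=np.float32)                # fed as the mask M

    for t in range(nsteps):
        obs = np.random.randn(nenv, 84, 84, 4).astype(np.float32)
        actions, state = step(obs, S=state, M=dones)        # output state goes back into S
        dones = np.random.binomial(1, 0.1, nenv).astype(np.float32)
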
@@ -52,43 +33,49 @@ def ppo_lstm(nlstm=128, layer_norm=False):
 @register("ppo_cnn_lstm")
-def ppo_cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
-    def network_fn(X, mask, nenv=1):
-        nbatch = X.shape[0]
-        # mask = tf.placeholder(tf.float32, [nbatch], name='mask')
-        mask = tf.to_float(mask)
-        state = tf.placeholder(tf.float32, [nbatch, 2 * nlstm], name='state')
-
-        init = tf.constant_initializer(np.sqrt(2))
-        h = tf.contrib.layers.conv2d(X,
-                                     num_outputs=32,
-                                     kernel_size=8,
-                                     stride=4,
-                                     padding="VALID",
-                                     weights_initializer=init)
-        h2 = tf.contrib.layers.conv2d(h,
-                                      num_outputs=64,
-                                      kernel_size=4,
-                                      stride=2,
-                                      padding="VALID",
-                                      weights_initializer=init)
-        h3 = tf.contrib.layers.conv2d(h2,
-                                      num_outputs=64,
-                                      kernel_size=3,
-                                      stride=1,
-                                      padding="VALID",
-                                      weights_initializer=init)
-        X = tf.layers.flatten(h3)
-        X = tf.layers.dense(X, units=512, activation=tf.nn.relu, kernel_initializer=init)
-
-        h, snew = lstm(X, mask, state, scope='lstm', nh=nlstm)
-        initial_state = np.zeros(state.shape.as_list(), dtype=float)
-        return h, {'prev': {'state': state, 'mask': mask},
-                   'post': {'state': snew}, }
+def ppo_cnn_lstm(nlstm=128, layer_norm=False):
+    def network_fn(input, mask):
+        memory_size = nlstm * 2
+        nbatch = input.shape[0]
+        mask.get_shape().assert_is_compatible_with([nbatch])
+        state = tf.Variable(np.zeros([nbatch, memory_size]),
+                            name='state',
+                            trainable=False,
+                            dtype=tf.float32,
+                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
+
+        def _network_fn(input, mask, state):
+            mask = tf.to_float(mask)
+            initializer = ortho_init(np.sqrt(2))
+
+            h = tf.contrib.layers.conv2d(input,
+                                         num_outputs=32,
+                                         kernel_size=8,
+                                         stride=4,
+                                         padding="VALID",
+                                         weights_initializer=initializer)
+            h = tf.contrib.layers.conv2d(h,
+                                         num_outputs=64,
+                                         kernel_size=4,
+                                         stride=2,
+                                         padding="VALID",
+                                         weights_initializer=initializer)
+            h = tf.contrib.layers.conv2d(h,
+                                         num_outputs=64,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding="VALID",
+                                         weights_initializer=initializer)
+            h = tf.layers.flatten(h)
+            h = tf.layers.dense(h, units=512, activation=tf.nn.relu, kernel_initializer=initializer)
+
+            if layer_norm:
+                h, next_state = lnlstm(h, mask, state, scope='lnlstm', nh=nlstm)
+            else:
+                h, next_state = lstm(h, mask, state, scope='lstm', nh=nlstm)
+            return h, next_state
+
+        return state, _network_fn

     return network_fn
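
Editor's note, not part of the commit: after this change ppo_cnn_lstm mirrors ppo_lstm -- network_fn(input, mask) creates a non-trainable local variable holding the [nbatch, 2 * nlstm] recurrent state and returns it together with an inner _network_fn(input, mask, state) that yields (features, next_state). A usage sketch under the assumption of TensorFlow 1.x is given below; the placeholder shapes and the explicit assign of next_state back into the state variable are illustrative assumptions, not code from this commit.

    import tensorflow as tf
    from baselines.ppo2.layers import ppo_cnn_lstm

    nbatch, nlstm = 8, 128
    obs_ph = tf.placeholder(tf.float32, [nbatch, 84, 84, 4], name='obs')
    mask_ph = tf.placeholder(tf.float32, [nbatch], name='mask')

    # network_fn now returns the state variable plus the graph-building closure.
    state_var, build = ppo_cnn_lstm(nlstm=nlstm)(obs_ph, mask_ph)
    features, next_state = build(obs_ph, mask_ph, state_var)

    # One possible way to carry the recurrent state across calls: write
    # next_state back into the local variable between rollout steps.
    update_state = tf.assign(state_var, next_state)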