improve lstm code.
@@ -1 +1 @@
-from baselines.ppo2.layers import *
+from baselines.ppo2.layers import ppo_lstm, ppo_cnn_lstm
@@ -7,25 +7,6 @@ from baselines.common.models import register
 
 @register("ppo_lstm")
 def ppo_lstm(nlstm=128, layer_norm=False):
-    """
-    Builds LSTM (Long-Short Term Memory) network to be used in a policy.
-    Note that the resulting function returns not only the output of the LSTM
-    (i.e. hidden state of lstm for each step in the sequence), but also a dictionary
-    with auxiliary tensors to be set as policy attributes.
-
-    Specifically,
-    S is a placeholder to feed current state (LSTM state has to be managed outside policy)
-    M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
-    initial_state is a numpy array containing initial lstm state (usually zeros)
-    state is the output LSTM state (to be fed into S at the next call)
-
-    nlstm: int. LSTM hidden state size
-    layer_norm: bool. if True, layer-normalized version of LSTM is used
-
-    Returns:
-    function that builds LSTM with a given input tensor / placeholder
-    """
-
     def network_fn(input, mask):
         memory_size = nlstm * 2
         nbatch = input.shape[0]
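The docstring removed above describes the state-handling contract of the recurrent policies: the LSTM state is managed outside the policy and fed back in on every call. A hedged rollout sketch of that contract, where `model`, `env`, `obs`, `dones`, and `nsteps` are assumed names and not part of this commit:

```python
# Hypothetical rollout loop illustrating the S / M / initial_state / state
# contract described in the removed docstring; `model`, `env`, `obs`, `dones`
# and `nsteps` are assumed names, not part of this commit.
obs = env.reset()
dones = [False] * env.num_envs
states = model.initial_state          # numpy zeros, shape [nbatch, 2 * nlstm]
for _ in range(nsteps):
    # S feeds the current LSTM state; M masks it out where an episode just ended
    actions, values, states, neglogpacs = model.step(obs, S=states, M=dones)
    obs, rewards, dones, infos = env.step(actions)
```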
@@ -52,43 +33,49 @@ def ppo_lstm(nlstm=128, layer_norm=False):
 
 @register("ppo_cnn_lstm")
-def ppo_cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
-    def network_fn(X, mask, nenv=1):
-        nbatch = X.shape[0]
-        # mask = tf.placeholder(tf.float32, [nbatch], name='mask')
-        mask = tf.to_float(mask)
-        state = tf.placeholder(tf.float32, [nbatch, 2 * nlstm], name='state')
-
-        init = tf.constant_initializer(np.sqrt(2))
-        h = tf.contrib.layers.conv2d(X,
-                                     num_outputs=32,
-                                     kernel_size=8,
-                                     stride=4,
-                                     padding="VALID",
-                                     weights_initializer=init)
-        h2 = tf.contrib.layers.conv2d(h,
-                                      num_outputs=64,
-                                      kernel_size=4,
-                                      stride=2,
-                                      padding="VALID",
-                                      weights_initializer=init)
-        h3 = tf.contrib.layers.conv2d(h2,
-                                      num_outputs=64,
-                                      kernel_size=3,
-                                      stride=1,
-                                      padding="VALID",
-                                      weights_initializer=init)
-        X = tf.layers.flatten(h3)
-        X = tf.layers.dense(X, units=512, activation=tf.nn.relu, kernel_initializer=init)
-
-        h, snew = lstm(X, mask, state, scope='lstm', nh=nlstm)
-        initial_state = np.zeros(state.shape.as_list(), dtype=float)
-
-        return h, {'prev': {'state': state, 'mask': mask},
-                   'post': {'state': snew}, }
+def ppo_cnn_lstm(nlstm=128, layer_norm=False):
+    def network_fn(input, mask):
+        memory_size = nlstm * 2
+        nbatch = input.shape[0]
+        mask.get_shape().assert_is_compatible_with([nbatch])
+        state = tf.Variable(np.zeros([nbatch, memory_size]),
+                            name='state',
+                            trainable=False,
+                            dtype=tf.float32,
+                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
+
+        def _network_fn(input, mask, state):
+            mask = tf.to_float(mask)
+            initializer = ortho_init(np.sqrt(2))
+
+            h = tf.contrib.layers.conv2d(input,
+                                         num_outputs=32,
+                                         kernel_size=8,
+                                         stride=4,
+                                         padding="VALID",
+                                         weights_initializer=initializer)
+            h = tf.contrib.layers.conv2d(h,
+                                         num_outputs=64,
+                                         kernel_size=4,
+                                         stride=2,
+                                         padding="VALID",
+                                         weights_initializer=initializer)
+            h = tf.contrib.layers.conv2d(h,
+                                         num_outputs=64,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding="VALID",
+                                         weights_initializer=initializer)
+            h = tf.layers.flatten(h)
+            h = tf.layers.dense(h, units=512, activation=tf.nn.relu, kernel_initializer=initializer)
+
+            if layer_norm:
+                h, next_state = lnlstm(h, mask, state, scope='lnlstm', nh=nlstm)
+            else:
+                h, next_state = lstm(h, mask, state, scope='lstm', nh=nlstm)
+            return h, next_state
+
+        return state, _network_fn
 
     return network_fn
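With this change, `network_fn` no longer returns the latent together with a dictionary of auxiliary tensors; it returns the non-trainable `state` variable plus an inner `_network_fn` that maps `(input, mask, state)` to `(h, next_state)`. A minimal sketch of that calling convention, assuming placeholder names `obs_ph`, `mask_ph`, and a batch size `nbatch` that are not part of this commit:

```python
# Minimal sketch of the factory convention introduced above; `obs_ph`,
# `mask_ph`, and `nbatch` are assumed names, not part of this commit.
nbatch = 16
obs_ph = tf.placeholder(tf.float32, [nbatch, 84, 84, 4], name='obs')
mask_ph = tf.placeholder(tf.float32, [nbatch], name='mask')

network_fn = ppo_cnn_lstm(nlstm=128, layer_norm=False)
state, apply_fn = network_fn(obs_ph, mask_ph)      # state: non-trainable tf.Variable
h, next_state = apply_fn(obs_ph, mask_ph, state)   # latent features and updated LSTM state
```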