name the memory variable of PPO RNNs more describly

This commit is contained in:
gyunt
2019-03-23 05:25:49 +09:00
parent 06cef53de3
commit 43a86980ea

View File

@@ -12,7 +12,7 @@ def ppo_lstm(nlstm=128, layer_norm=False):
nbatch = input.shape[0]
mask.get_shape().assert_is_compatible_with([nbatch])
state = tf.Variable(np.zeros([nbatch, memory_size]),
name='state',
name='lstm_state',
trainable=False,
dtype=tf.float32,
collections=[tf.GraphKeys.LOCAL_VARIABLES])
@@ -39,7 +39,7 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
nbatch = input.shape[0]
mask.get_shape().assert_is_compatible_with([nbatch])
state = tf.Variable(np.zeros([nbatch, memory_size]),
name='state',
name='lstm_state',
trainable=False,
dtype=tf.float32,
collections=[tf.GraphKeys.LOCAL_VARIABLES])