name the memory variable of PPO RNNs more describly
This commit is contained in:
@@ -12,7 +12,7 @@ def ppo_lstm(nlstm=128, layer_norm=False):
|
|||||||
nbatch = input.shape[0]
|
nbatch = input.shape[0]
|
||||||
mask.get_shape().assert_is_compatible_with([nbatch])
|
mask.get_shape().assert_is_compatible_with([nbatch])
|
||||||
state = tf.Variable(np.zeros([nbatch, memory_size]),
|
state = tf.Variable(np.zeros([nbatch, memory_size]),
|
||||||
name='state',
|
name='lstm_state',
|
||||||
trainable=False,
|
trainable=False,
|
||||||
dtype=tf.float32,
|
dtype=tf.float32,
|
||||||
collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||||
@@ -39,7 +39,7 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
|
|||||||
nbatch = input.shape[0]
|
nbatch = input.shape[0]
|
||||||
mask.get_shape().assert_is_compatible_with([nbatch])
|
mask.get_shape().assert_is_compatible_with([nbatch])
|
||||||
state = tf.Variable(np.zeros([nbatch, memory_size]),
|
state = tf.Variable(np.zeros([nbatch, memory_size]),
|
||||||
name='state',
|
name='lstm_state',
|
||||||
trainable=False,
|
trainable=False,
|
||||||
dtype=tf.float32,
|
dtype=tf.float32,
|
||||||
collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||||
|
Reference in New Issue
Block a user