name the memory variable of PPO RNNs more describly
This commit is contained in:
@@ -12,7 +12,7 @@ def ppo_lstm(nlstm=128, layer_norm=False):
|
||||
nbatch = input.shape[0]
|
||||
mask.get_shape().assert_is_compatible_with([nbatch])
|
||||
state = tf.Variable(np.zeros([nbatch, memory_size]),
|
||||
name='state',
|
||||
name='lstm_state',
|
||||
trainable=False,
|
||||
dtype=tf.float32,
|
||||
collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
@@ -39,7 +39,7 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
|
||||
nbatch = input.shape[0]
|
||||
mask.get_shape().assert_is_compatible_with([nbatch])
|
||||
state = tf.Variable(np.zeros([nbatch, memory_size]),
|
||||
name='state',
|
||||
name='lstm_state',
|
||||
trainable=False,
|
||||
dtype=tf.float32,
|
||||
collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
|
Reference in New Issue
Block a user