diff --git a/baselines/ppo2/__init__.py b/baselines/ppo2/__init__.py
index b2832d9..ecb4342 100644
--- a/baselines/ppo2/__init__.py
+++ b/baselines/ppo2/__init__.py
@@ -1 +1 @@
-from baselines.ppo2.layers import *
\ No newline at end of file
+from baselines.ppo2.layers import ppo_lstm, ppo_cnn_lstm
diff --git a/baselines/ppo2/layers.py b/baselines/ppo2/layers.py
index 2eb0502..eddcd43 100644
--- a/baselines/ppo2/layers.py
+++ b/baselines/ppo2/layers.py
@@ -7,25 +7,6 @@ from baselines.common.models import register
 @register("ppo_lstm")
 def ppo_lstm(nlstm=128, layer_norm=False):
-    """
-    Builds LSTM (Long-Short Term Memory) network to be used in a policy.
-    Note that the resulting function returns not only the output of the LSTM
-    (i.e. hidden state of lstm for each step in the sequence), but also a dictionary
-    with auxiliary tensors to be set as policy attributes.
-
-    Specifically,
-        S is a placeholder to feed current state (LSTM state has to be managed outside policy)
-        M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
-        initial_state is a numpy array containing initial lstm state (usually zeros)
-        state is the output LSTM state (to be fed into S at the next call)
-
-    nlstm: int. LSTM hidden state size
-    layer_norm: bool. if True, layer-normalized version of LSTM is used
-
-    Returns:
-        function that builds LSTM with a given input tensor / placeholder
-    """
-
     def network_fn(input, mask):
         memory_size = nlstm * 2
         nbatch = input.shape[0]
@@ -52,43 +33,49 @@ def ppo_lstm(nlstm=128, layer_norm=False):
 @register("ppo_cnn_lstm")
-def ppo_cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
-    def network_fn(X, mask, nenv=1):
-        nbatch = X.shape[0]
-        # mask = tf.placeholder(tf.float32, [nbatch], name='mask')
-        mask = tf.to_float(mask)
+def ppo_cnn_lstm(nlstm=128, layer_norm=False):
+    def network_fn(input, mask):
+        memory_size = nlstm * 2
+        nbatch = input.shape[0]
         mask.get_shape().assert_is_compatible_with([nbatch])
+        state = tf.Variable(np.zeros([nbatch, memory_size]),
+                            name='state',
+                            trainable=False,
+                            dtype=tf.float32,
+                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
-        state = tf.placeholder(tf.float32, [nbatch, 2 * nlstm], name='state')
+        def _network_fn(input, mask, state):
+            mask = tf.to_float(mask)
+            initializer = ortho_init(np.sqrt(2))
-        init = tf.constant_initializer(np.sqrt(2))
+            h = tf.contrib.layers.conv2d(input,
+                                         num_outputs=32,
+                                         kernel_size=8,
+                                         stride=4,
+                                         padding="VALID",
+                                         weights_initializer=initializer)
+            h = tf.contrib.layers.conv2d(h,
+                                         num_outputs=64,
+                                         kernel_size=4,
+                                         stride=2,
+                                         padding="VALID",
+                                         weights_initializer=initializer)
+            h = tf.contrib.layers.conv2d(h,
+                                         num_outputs=64,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding="VALID",
+                                         weights_initializer=initializer)
+            h = tf.layers.flatten(h)
+            h = tf.layers.dense(h, units=512, activation=tf.nn.relu, kernel_initializer=initializer)
-        h = tf.contrib.layers.conv2d(X,
-                                     num_outputs=32,
-                                     kernel_size=8,
-                                     stride=4,
-                                     padding="VALID",
-                                     weights_initializer=init)
-        h2 = tf.contrib.layers.conv2d(h,
-                                      num_outputs=64,
-                                      kernel_size=4,
-                                      stride=2,
-                                      padding="VALID",
-                                      weights_initializer=init)
-        h3 = tf.contrib.layers.conv2d(h2,
-                                      num_outputs=64,
-                                      kernel_size=3,
-                                      stride=1,
-                                      padding="VALID",
-                                      weights_initializer=init)
-        X = tf.layers.flatten(h3)
-        X = tf.layers.dense(X, units=512, activation=tf.nn.relu, kernel_initializer=init)
+            if layer_norm:
+                h, next_state = lnlstm(h, mask, state, scope='lnlstm', nh=nlstm)
+            else:
+                h, next_state = lstm(h, mask, state, scope='lstm', nh=nlstm)
+            return h, next_state
-        h, snew = lstm(X, mask, state, scope='lstm', nh=nlstm)
-        initial_state = np.zeros(state.shape.as_list(), dtype=float)
-
-        return h, {'prev': {'state': state, 'mask': mask},
-                   'post': {'state': snew}, }
+        return state, _network_fn
     return network_fn
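
Note for reviewers: the new contract is two-stage. Calling `network_fn(input, mask)` allocates the recurrent memory as a non-trainable `tf.Variable` in `LOCAL_VARIABLES` and returns it together with `_network_fn`, which actually builds the conv + LSTM graph. This replaces the placeholder-based round-trip (`state` / `initial_state` carried through the feed dict, as described in the removed docstring). Below is a minimal sketch of how calling code might drive this, assuming TF 1.x; the observation shape, session loop, and `tf.assign` update are illustrative and not part of this PR.

```python
import numpy as np
import tensorflow as tf

from baselines.ppo2 import ppo_cnn_lstm

nbatch, nlstm = 4, 128

# Illustrative inputs; any image-shaped placeholder with a static batch works.
obs = tf.placeholder(tf.float32, [nbatch, 84, 84, 4], name='obs')
mask = tf.placeholder(tf.float32, [nbatch], name='mask')

# Stage 1: allocate the persistent memory variable and get the graph builder.
state, _network_fn = ppo_cnn_lstm(nlstm=nlstm)(obs, mask)

# Stage 2: build the conv + LSTM graph against that state.
h, next_state = _network_fn(obs, mask, state)

# Carrying memory across steps becomes an in-graph assignment
# rather than a feed-dict round-trip.
update_state = tf.assign(state, next_state)

with tf.Session() as sess:
    # The state variable lives in LOCAL_VARIABLES, so the local
    # initializer is needed in addition to the global one.
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    feats, _ = sess.run(
        [h, update_state],
        feed_dict={obs: np.zeros((nbatch, 84, 84, 4), np.float32),
                   mask: np.zeros(nbatch, np.float32)})
```

One consequence worth flagging: since the memory now lives in a graph variable keyed to a fixed `nbatch`, callers that previously reset an episode by feeding a fresh `initial_state` will need an explicit assign (e.g. zeroing `state`) instead.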