From 2a4ba2b0a5b607d608e3f0a22779cf8b51b46cd9 Mon Sep 17 00:00:00 2001
From: gyunt
Date: Wed, 27 Mar 2019 07:54:15 +0900
Subject: [PATCH] add RNN layers.

---
 baselines/ppo2/layers.py | 108 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 107 insertions(+), 1 deletion(-)

diff --git a/baselines/ppo2/layers.py b/baselines/ppo2/layers.py
index 90d42ec..04c5e0d 100644
--- a/baselines/ppo2/layers.py
+++ b/baselines/ppo2/layers.py
@@ -1,7 +1,7 @@
 import numpy as np
 import tensorflow as tf
 
-from baselines.a2c.utils import ortho_init
+from baselines.a2c.utils import ortho_init, fc
 from baselines.common.models import register
 
 
@@ -88,6 +88,96 @@ def ppo_cnn_lnlstm(nlstm=128, **conv_kwargs):
     return ppo_cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
 
 
+@register("ppo_gru")
+def ppo_gru(nlstm=128):
+    def network_fn(input, mask):
+        memory_size = nlstm
+        nbatch = input.shape[0]
+        mask.get_shape().assert_is_compatible_with([nbatch])
+        state = tf.Variable(np.zeros([nbatch, memory_size]),
+                            name='gru_state',
+                            trainable=False,
+                            dtype=tf.float32,
+                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
+
+        def _network_fn(input, mask, state):
+            input = tf.layers.flatten(input)
+            mask = tf.to_float(mask)
+
+            h, next_state = gru(input, mask, state, nh=nlstm)
+            return h, next_state
+
+        return state, _network_fn
+
+    return network_fn
+
+
+@register("ppo_lstm_mlp")
+def ppo_lstm_mlp(nlstm=128, layer_norm=False):
+    def network_fn(input, mask):
+        memory_size = nlstm * 2
+        nbatch = input.shape[0]
+        mask.get_shape().assert_is_compatible_with([nbatch])
+        state = tf.Variable(np.zeros([nbatch, memory_size]),
+                            name='lstm_state',
+                            trainable=False,
+                            dtype=tf.float32,
+                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
+
+        def _network_fn(input, mask, state):
+            input = tf.layers.flatten(input)
+            mask = tf.to_float(mask)
+
+            if layer_norm:
+                h, next_state = lnlstm(input, mask, state, scope='lnlstm', nh=nlstm)
+            else:
+                h, next_state = lstm(input, mask, state, scope='lstm', nh=nlstm)
+
+            num_layers = 2
+            num_hidden = 64
+            activation = tf.nn.relu
+            for i in range(num_layers):
+                h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
+                h = activation(h)
+            return h, next_state
+
+        return state, _network_fn
+
+    return network_fn
+
+
+@register("ppo_gru_mlp")
+def ppo_gru_mlp(nlstm=128):
+    def network_fn(input, mask):
+        memory_size = nlstm
+        nbatch = input.shape[0]
+        mask.get_shape().assert_is_compatible_with([nbatch])
+        state = tf.Variable(np.zeros([nbatch, memory_size]),
+                            name='gru_state',
+                            trainable=False,
+                            dtype=tf.float32,
+                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
+
+        def _network_fn(input, mask, state):
+            input = tf.layers.flatten(input)
+            mask = tf.to_float(mask)
+
+            h, next_state = gru(input, mask, state, nh=nlstm)
+
+            num_layers = 2
+            num_hidden = 64
+            activation = tf.nn.relu
+            for i in range(num_layers):
+                h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
+                h = activation(h)
+
+            return h, next_state
+
+        return state, _network_fn
+
+    return network_fn
+
+
 def lstm(x, m, s, scope, nh, init_scale=1.0):
     x = tf.layers.flatten(x)
     nin = x.get_shape()[1]
@@ -155,3 +245,19 @@ def lnlstm(x, m, s, scope, nh, init_scale=1.0):
     s = tf.concat(axis=1, values=[c, h])
 
     return h, s
+
+
+def gru(x, mask, state, nh, init_bias=-1.0):
+    """Gated recurrent unit (GRU) with nh cells."""
+    h = state
+    mask = tf.tile(tf.expand_dims(mask, axis=-1), multiples=[1, nh])
+
+    h *= (1.0 - mask)  # reset the carried state where an episode ended (mask == 1)
+    hx = tf.concat([h, x], axis=1)
+    mr = tf.sigmoid(fc(hx, nh=nh * 2, scope='gru_mr', init_bias=init_bias))
+    # m: remember (update) gate strength, r: read (reset) gate strength
+    m, r = tf.split(mr, 2, axis=1)
+    rh_x = tf.concat([r * h, x], axis=1)
+    htil = tf.tanh(fc(rh_x, nh=nh, scope='gru_htil'))
+    h = m * h + (1.0 - m) * htil
+    return h, h  # a GRU's output and next state coincide
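
A quick sanity check of the gru cell above, as a standalone sketch: it builds the cell once, feeds a batch with mask=1 everywhere, and confirms the output shape and the mask-driven state reset. This assumes a TF1-style graph session (as baselines uses) and that the patch is applied so baselines.ppo2.layers exposes gru:

import numpy as np
import tensorflow as tf

from baselines.ppo2.layers import gru  # added by this patch

nbatch, nin, nh = 4, 10, 8
x = tf.placeholder(tf.float32, [nbatch, nin])     # flattened observations
mask = tf.placeholder(tf.float32, [nbatch])       # 1.0 where an episode just ended
state = tf.placeholder(tf.float32, [nbatch, nh])  # carried GRU state

h, next_state = gru(x, mask, state, nh=nh)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(h, {x: np.random.randn(nbatch, nin).astype(np.float32),
                       mask: np.ones(nbatch, dtype=np.float32),  # reset everywhere
                       state: np.random.randn(nbatch, nh).astype(np.float32)})
    print(out.shape)  # (4, 8)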
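
End-to-end, the registered names should be usable wherever this fork resolves policy networks through baselines.common.models.get_network_builder (upstream does so in baselines.run and ppo2.learn). A hypothetical training invocation, assuming this fork's ppo2.learn keeps the upstream signature and forwards extra keyword arguments such as nlstm to the network builder:

from baselines.common.cmd_util import make_vec_env
from baselines.ppo2 import ppo2

env = make_vec_env('CartPole-v1', 'classic_control', num_env=4, seed=0)
model = ppo2.learn(network='ppo_gru_mlp',  # name registered in this patch
                   env=env,
                   nsteps=128,
                   nminibatches=4,         # upstream requires nenvs % nminibatches == 0 for recurrent nets
                   total_timesteps=50000,
                   nlstm=64)               # forwarded to ppo_gru_mlp(nlstm=64)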