add RNN layers.
@@ -1,7 +1,7 @@
 import numpy as np
 import tensorflow as tf
 
-from baselines.a2c.utils import ortho_init
+from baselines.a2c.utils import ortho_init, fc
 from baselines.common.models import register
 
 
@@ -88,6 +88,93 @@ def ppo_cnn_lnlstm(nlstm=128, **conv_kwargs):
     return ppo_cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
 
 
+@register("ppo_gru")
+def ppo_gru(nlstm=128):
+    def network_fn(input, mask):
+        memory_size = nlstm
+        nbatch = input.shape[0]
+        mask.get_shape().assert_is_compatible_with([nbatch])
+        state = tf.Variable(np.zeros([nbatch, memory_size]),
+                            name='gru_state',
+                            trainable=False,
+                            dtype=tf.float32,
+                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
+
+        def _network_fn(input, mask, state):
+            input = tf.layers.flatten(input)
+            mask = tf.to_float(mask)
+
+            h, next_state = gru(input, mask, state, nh=nlstm)
+            return h, next_state
+
+        return state, _network_fn
+
+    return network_fn
+
+
+@register("ppo_lstm_mlp")
+def ppo_lstm_mlp(nlstm=128, layer_norm=False):
+    def network_fn(input, mask):
+        memory_size = nlstm * 2
+        nbatch = input.shape[0]
+        mask.get_shape().assert_is_compatible_with([nbatch])
+        state = tf.Variable(np.zeros([nbatch, memory_size]),
+                            name='lstm_state',
+                            trainable=False,
+                            dtype=tf.float32,
+                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
+
+        def _network_fn(input, mask, state):
+            input = tf.layers.flatten(input)
+            mask = tf.to_float(mask)
+
+            h, next_state = lstm(input, mask, state, scope='lstm', nh=nlstm)
+
+            num_layers = 2
+            num_hidden = 64
+            activation = tf.nn.relu
+            for i in range(num_layers):
+                h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
+                h = activation(h)
+            return h, next_state
+
+        return state, _network_fn
+
+    return network_fn
+
+
+@register("ppo_gru_mlp")
+def ppo_gru_mlp(nlstm=128):
+    def network_fn(input, mask):
+        memory_size = nlstm
+        nbatch = input.shape[0]
+        mask.get_shape().assert_is_compatible_with([nbatch])
+        state = tf.Variable(np.zeros([nbatch, memory_size]),
+                            name='gru_state',
+                            trainable=False,
+                            dtype=tf.float32,
+                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
+
+        def _network_fn(input, mask, state):
+            input = tf.layers.flatten(input)
+            mask = tf.to_float(mask)
+
+            h, next_state = gru(input, mask, state, nh=nlstm)
+
+            num_layers = 2
+            num_hidden = 64
+            activation = tf.nn.relu
+            for i in range(num_layers):
+                h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
+                h = activation(h)
+
+            return h, next_state
+
+        return state, _network_fn
+
+    return network_fn
+
+
 def lstm(x, m, s, scope, nh, init_scale=1.0):
     x = tf.layers.flatten(x)
     nin = x.get_shape()[1]
@@ -155,3 +242,19 @@ def lnlstm(x, m, s, scope, nh, init_scale=1.0):
     s = tf.concat(axis=1, values=[c, h])
 
     return h, s
+
+
+def gru(x, mask, state, nh, init_scale=-1.0):
+    """Gated recurrent unit (GRU) with nh cells."""
+    h = state
+    mask = tf.tile(tf.expand_dims(mask, axis=-1), multiples=[1, nh])
+
+    h *= (1.0 - mask)
+    hx = tf.concat([h, x], axis=1)
+    mr = tf.sigmoid(fc(hx, nh=nh * 2, scope='gru_mr', init_bias=init_scale))
+    # r: read strength; m: 'member (i.e. remember) strength
+    m, r = tf.split(mr, 2, axis=1)
+    rh_x = tf.concat([r * h, x], axis=1)
+    htil = tf.tanh(fc(rh_x, nh=nh, scope='gru_htil'))
+    h = m * h + (1.0 - m) * htil
+    return h, h
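
The gru() cell is a standard GRU with a mask-based reset: where mask is 1 (an episode start), the incoming state is zeroed before the gates act, and the update is h' = m * h + (1 - m) * tanh(W [r * h, x]). Note that init_scale is passed to fc() as init_bias, so the default of -1.0 starts both gates biased low. A minimal NumPy sketch of the same recurrence, where W_mr, b_mr, W_h, b_h are illustrative stand-ins for the parameters the two fc() calls would create (not names from this commit):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x, mask, h, W_mr, b_mr, W_h, b_h):
    """One GRU step: x is [nbatch, nin], mask is [nbatch], h is [nbatch, nh]."""
    nh = h.shape[1]
    mask = np.tile(mask[:, None], (1, nh))   # broadcast mask over units
    h = h * (1.0 - mask)                     # zero state at episode starts
    hx = np.concatenate([h, x], axis=1)
    mr = sigmoid(hx @ W_mr + b_mr)           # joint 'member/read gates
    m, r = np.split(mr, 2, axis=1)
    rh_x = np.concatenate([r * h, x], axis=1)
    htil = np.tanh(rh_x @ W_h + b_h)         # candidate state
    return m * h + (1.0 - m) * htil          # convex mix of old and new

# Shapes match the two fc() calls: W_mr is [(nh + nin), 2 * nh], W_h is [(nh + nin), nh].
nin, nh, nbatch = 4, 3, 2
rng = np.random.RandomState(0)
h = gru_step(rng.randn(nbatch, nin), np.array([1.0, 0.0]), rng.randn(nbatch, nh),
             rng.randn(nh + nin, 2 * nh), -np.ones(2 * nh),   # bias -1, as init_scale=-1.0
             rng.randn(nh + nin, nh), np.zeros(nh))
print(h.shape)  # (2, 3)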
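
The three registered builders share a two-stage protocol: network_fn(input, mask) creates a non-trainable state variable (placed in LOCAL_VARIABLES, keeping it out of the trainable set) and returns it together with _network_fn, which runs one recurrent step and hands back features plus the next state. For the LSTM variant, memory_size is nlstm * 2 because the cell and hidden vectors are concatenated into a single state, as in the s = tf.concat(axis=1, values=[c, h]) context line above. A rough TF1 sketch of how a caller might wire this up, assuming the functions above are in scope; the placeholder names and the explicit assign are assumptions for illustration, not code from this commit:

import numpy as np
import tensorflow as tf

nbatch, nobs = 8, 16
obs_ph = tf.placeholder(tf.float32, [nbatch, nobs])
mask_ph = tf.placeholder(tf.float32, [nbatch])   # 1.0 flags an episode start

network_fn = ppo_gru(nlstm=128)                  # builder registered above
state, step_fn = network_fn(obs_ph, mask_ph)     # creates the gru_state variable
h, next_state = step_fn(obs_ph, mask_ph, state)  # one recurrent step
carry = state.assign(next_state)                 # persist state between steps

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])  # state lives in LOCAL_VARIABLES
    feats, _ = sess.run([h, carry],
                        {obs_ph: np.zeros([nbatch, nobs], dtype=np.float32),
                         mask_ph: np.ones(nbatch, dtype=np.float32)})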