add cnn_lnlstm layer.

This commit is contained in:
gyunt
2019-03-22 05:27:45 +09:00
parent 8d05917c87
commit c320622fe4

View File

@@ -83,6 +83,11 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
return network_fn
@register("ppo_cnn_lnlstm")
def ppo_cnn_lnlstm(nlstm=128, **conv_kwargs):
return ppo_cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
def lstm(x, m, s, scope, nh, init_scale=1.0):
x = tf.layers.flatten(x)
nin = x.get_shape()[1]