diff --git a/baselines/ppo2/layers.py b/baselines/ppo2/layers.py index 8493f19..3a2445c 100644 --- a/baselines/ppo2/layers.py +++ b/baselines/ppo2/layers.py @@ -83,6 +83,11 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs): return network_fn +@register("ppo_cnn_lnlstm") +def ppo_cnn_lnlstm(nlstm=128, **conv_kwargs): + return ppo_cnn_lstm(nlstm, layer_norm=True, **conv_kwargs) + + def lstm(x, m, s, scope, nh, init_scale=1.0): x = tf.layers.flatten(x) nin = x.get_shape()[1]