add cnn_lnlstm layer.
This commit is contained in:
@@ -83,6 +83,11 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
|
|||||||
return network_fn
|
return network_fn
|
||||||
|
|
||||||
|
|
||||||
|
@register("ppo_cnn_lnlstm")
|
||||||
|
def ppo_cnn_lnlstm(nlstm=128, **conv_kwargs):
|
||||||
|
return ppo_cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def lstm(x, m, s, scope, nh, init_scale=1.0):
|
def lstm(x, m, s, scope, nh, init_scale=1.0):
|
||||||
x = tf.layers.flatten(x)
|
x = tf.layers.flatten(x)
|
||||||
nin = x.get_shape()[1]
|
nin = x.get_shape()[1]
|
||||||
|
Reference in New Issue
Block a user