removed ppo_lstm_mlp
.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from baselines.a2c.utils import ortho_init, fc, lstm, lnlstm
|
||||
from baselines.a2c.utils import ortho_init, lstm, lnlstm
|
||||
from baselines.common.models import register, nature_cnn
|
||||
|
||||
|
||||
@@ -53,26 +53,3 @@ def ppo_cnn_lstm(num_units=128, layer_norm=False, **conv_kwargs):
|
||||
@register("ppo_cnn_lnlstm")
|
||||
def ppo_cnn_lnlstm(num_units=128, **conv_kwargs):
|
||||
return ppo_cnn_lstm(num_units, layer_norm=True, **conv_kwargs)
|
||||
|
||||
|
||||
@register("ppo_lstm_mlp")
|
||||
def ppo_lstm_mlp(num_units=128, layer_norm=False):
|
||||
def _network_fn(input, mask, state):
|
||||
h = tf.layers.flatten(input)
|
||||
mask = tf.to_float(mask)
|
||||
|
||||
if layer_norm:
|
||||
h, next_state = lnlstm([h], [mask[:, None]], state, scope='lnlstm', nh=num_units)
|
||||
else:
|
||||
h, next_state = lstm([h], [mask[:, None]], state, scope='lstm', nh=num_units)
|
||||
h = h[0]
|
||||
|
||||
num_layers = 2
|
||||
num_hidden = 64
|
||||
activation = tf.nn.relu
|
||||
for i in range(num_layers):
|
||||
h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
|
||||
h = activation(h)
|
||||
return h, next_state
|
||||
|
||||
return RNN(_network_fn, num_units * 2)
|
||||
|
Reference in New Issue
Block a user