removed ppo_lstm_mlp
.
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from baselines.a2c.utils import ortho_init, fc, lstm, lnlstm
|
from baselines.a2c.utils import ortho_init, lstm, lnlstm
|
||||||
from baselines.common.models import register, nature_cnn
|
from baselines.common.models import register, nature_cnn
|
||||||
|
|
||||||
|
|
||||||
@@ -53,26 +53,3 @@ def ppo_cnn_lstm(num_units=128, layer_norm=False, **conv_kwargs):
|
|||||||
@register("ppo_cnn_lnlstm")
|
@register("ppo_cnn_lnlstm")
|
||||||
def ppo_cnn_lnlstm(num_units=128, **conv_kwargs):
|
def ppo_cnn_lnlstm(num_units=128, **conv_kwargs):
|
||||||
return ppo_cnn_lstm(num_units, layer_norm=True, **conv_kwargs)
|
return ppo_cnn_lstm(num_units, layer_norm=True, **conv_kwargs)
|
||||||
|
|
||||||
|
|
||||||
@register("ppo_lstm_mlp")
|
|
||||||
def ppo_lstm_mlp(num_units=128, layer_norm=False):
|
|
||||||
def _network_fn(input, mask, state):
|
|
||||||
h = tf.layers.flatten(input)
|
|
||||||
mask = tf.to_float(mask)
|
|
||||||
|
|
||||||
if layer_norm:
|
|
||||||
h, next_state = lnlstm([h], [mask[:, None]], state, scope='lnlstm', nh=num_units)
|
|
||||||
else:
|
|
||||||
h, next_state = lstm([h], [mask[:, None]], state, scope='lstm', nh=num_units)
|
|
||||||
h = h[0]
|
|
||||||
|
|
||||||
num_layers = 2
|
|
||||||
num_hidden = 64
|
|
||||||
activation = tf.nn.relu
|
|
||||||
for i in range(num_layers):
|
|
||||||
h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
|
|
||||||
h = activation(h)
|
|
||||||
return h, next_state
|
|
||||||
|
|
||||||
return RNN(_network_fn, num_units * 2)
|
|
||||||
|
Reference in New Issue
Block a user