removed ppo_lstm_mlp.

Author: gyunt
Date: 2019-04-09 02:04:35 +09:00
parent f63b09cf40
commit 02e26fd9df


@@ -1,7 +1,7 @@
 import numpy as np
 import tensorflow as tf
-from baselines.a2c.utils import ortho_init, fc, lstm, lnlstm
+from baselines.a2c.utils import ortho_init, lstm, lnlstm
 from baselines.common.models import register, nature_cnn
@@ -53,26 +53,3 @@ def ppo_cnn_lstm(num_units=128, layer_norm=False, **conv_kwargs):
 @register("ppo_cnn_lnlstm")
 def ppo_cnn_lnlstm(num_units=128, **conv_kwargs):
     return ppo_cnn_lstm(num_units, layer_norm=True, **conv_kwargs)
-
-
-@register("ppo_lstm_mlp")
-def ppo_lstm_mlp(num_units=128, layer_norm=False):
-    def _network_fn(input, mask, state):
-        h = tf.layers.flatten(input)
-        mask = tf.to_float(mask)
-        if layer_norm:
-            h, next_state = lnlstm([h], [mask[:, None]], state, scope='lnlstm', nh=num_units)
-        else:
-            h, next_state = lstm([h], [mask[:, None]], state, scope='lstm', nh=num_units)
-        h = h[0]
-        num_layers = 2
-        num_hidden = 64
-        activation = tf.nn.relu
-        for i in range(num_layers):
-            h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
-            h = activation(h)
-        return h, next_state
-    return RNN(_network_fn, num_units * 2)
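
For context on what this removal affects: functions decorated with @register in baselines.common.models are looked up by name when a policy network is built. A minimal usage sketch, assuming upstream baselines' get_network_builder helper is unchanged in this fork; after this commit, "ppo_lstm_mlp" no longer resolves, while the surviving entries such as "ppo_cnn_lnlstm" still do:

from baselines.common.models import get_network_builder

# Registered builders are looked up by name and called with their kwargs.
# "ppo_cnn_lnlstm" survives this commit and still resolves:
network_fn = get_network_builder("ppo_cnn_lnlstm")(num_units=128)

# "ppo_lstm_mlp" was removed, so this lookup would now raise
# ValueError("Unknown network type ..."):
# get_network_builder("ppo_lstm_mlp")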