lstm network builders using tf lstm
This commit is contained in:
@@ -92,6 +92,48 @@ def lstm(nlstm=128, layer_norm=False):
|
|||||||
|
|
||||||
return network_fn
|
return network_fn
|
||||||
|
|
||||||
|
def tflstm_static(nlstm=128, layer_norm=False):
|
||||||
|
def network_fn(X, nenv=1):
|
||||||
|
nbatch = X.shape[0]
|
||||||
|
nsteps = nbatch // nenv
|
||||||
|
|
||||||
|
h = tf.layers.flatten(X)
|
||||||
|
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(nlstm, state_is_tuple=False, forget_bias=0.0)
|
||||||
|
|
||||||
|
S = tf.placeholder(tf.float32, rnn_cell.zero_state(nenv, dtype=tf.float32).shape) #states
|
||||||
|
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
|
||||||
|
|
||||||
|
xs = batch_to_seq(h, nenv, nsteps)
|
||||||
|
|
||||||
|
h5, snew = tf.nn.static_rnn(rnn_cell, xs, initial_state=S)
|
||||||
|
|
||||||
|
h = seq_to_batch(h5)
|
||||||
|
|
||||||
|
initial_state = np.zeros(S.shape.as_list(), dtype=float)
|
||||||
|
|
||||||
|
return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
|
||||||
|
|
||||||
|
return network_fn
|
||||||
|
|
||||||
|
def tflstm(nlstm=128):
|
||||||
|
def network_fn(X, nenv=1):
|
||||||
|
nbatch = X.shape[0]
|
||||||
|
nsteps = nbatch // nenv
|
||||||
|
|
||||||
|
h = tf.layers.flatten(X)
|
||||||
|
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(nlstm, state_is_tuple=False, forget_bias=0.0)
|
||||||
|
|
||||||
|
S = tf.placeholder(tf.float32, rnn_cell.zero_state(nenv, dtype=tf.float32).shape) #states
|
||||||
|
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
|
||||||
|
initial_state = np.zeros(S.shape)
|
||||||
|
|
||||||
|
h = tf.reshape(h, (-1, nsteps, h.shape[-1]))
|
||||||
|
h, snew = tf.nn.dynamic_rnn(rnn_cell, h, initial_state=S)
|
||||||
|
|
||||||
|
h = tf.reshape(h, (-1, h.shape[-1]))
|
||||||
|
return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
|
||||||
|
|
||||||
|
return network_fn
|
||||||
|
|
||||||
def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
|
def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
|
||||||
def network_fn(X, nenv=1):
|
def network_fn(X, nenv=1):
|
||||||
@@ -169,6 +211,10 @@ def get_network_builder(name):
|
|||||||
return mlp
|
return mlp
|
||||||
elif name == 'lstm':
|
elif name == 'lstm':
|
||||||
return lstm
|
return lstm
|
||||||
|
elif name == 'tflstm_static':
|
||||||
|
return tflstm_static
|
||||||
|
elif name == 'tflstm':
|
||||||
|
return tflstm
|
||||||
elif name == 'cnn_lstm':
|
elif name == 'cnn_lstm':
|
||||||
return cnn_lstm
|
return cnn_lstm
|
||||||
elif name == 'cnn_lnlstm':
|
elif name == 'cnn_lnlstm':
|
||||||
|
@@ -6,7 +6,8 @@ from baselines.run import get_learn_function
|
|||||||
|
|
||||||
common_kwargs = dict(
|
common_kwargs = dict(
|
||||||
seed=0,
|
seed=0,
|
||||||
total_timesteps=50000,
|
total_timesteps=20000,
|
||||||
|
nlstm=64
|
||||||
)
|
)
|
||||||
|
|
||||||
learn_kwargs = {
|
learn_kwargs = {
|
||||||
@@ -19,7 +20,7 @@ learn_kwargs = {
|
|||||||
|
|
||||||
|
|
||||||
alg_list = learn_kwargs.keys()
|
alg_list = learn_kwargs.keys()
|
||||||
rnn_list = ['lstm']
|
rnn_list = ['lstm', 'tflstm', 'tflstm_static']
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.parametrize("alg", alg_list)
|
@pytest.mark.parametrize("alg", alg_list)
|
||||||
@@ -41,11 +42,11 @@ def test_fixed_sequence(alg, rnn):
|
|||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
simple_test(env_fn, learn, 0.7)
|
simple_test(env_fn, learn, 0.3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_fixed_sequence('ppo2', 'lstm')
|
test_fixed_sequence('ppo2', 'tflstm')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@@ -2,6 +2,7 @@ import tensorflow as tf
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gym.spaces import np_random
|
from gym.spaces import np_random
|
||||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||||
|
from baselines.bench.monitor import Monitor
|
||||||
|
|
||||||
N_TRIALS = 10000
|
N_TRIALS = 10000
|
||||||
N_EPISODES = 100
|
N_EPISODES = 100
|
||||||
@@ -10,7 +11,7 @@ def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
|||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
np_random.seed(0)
|
np_random.seed(0)
|
||||||
|
|
||||||
env = DummyVecEnv([env_fn])
|
env = DummyVecEnv([lambda: Monitor(env_fn(), None, allow_early_resets=True)])
|
||||||
|
|
||||||
|
|
||||||
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
||||||
|
Reference in New Issue
Block a user