Compare commits
1 Commits
peterz_tfl
...
peterz_tfl
Author | SHA1 | Date | |
---|---|---|---|
|
dbcc4e0252 |
@@ -92,48 +92,6 @@ def lstm(nlstm=128, layer_norm=False):
|
||||
|
||||
return network_fn
|
||||
|
||||
def tflstm_static(nlstm=128, layer_norm=False):
|
||||
def network_fn(X, nenv=1):
|
||||
nbatch = X.shape[0]
|
||||
nsteps = nbatch // nenv
|
||||
|
||||
h = tf.layers.flatten(X)
|
||||
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(nlstm, state_is_tuple=False, forget_bias=0.0)
|
||||
|
||||
S = tf.placeholder(tf.float32, rnn_cell.zero_state(nenv, dtype=tf.float32).shape) #states
|
||||
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
|
||||
|
||||
xs = batch_to_seq(h, nenv, nsteps)
|
||||
|
||||
h5, snew = tf.nn.static_rnn(rnn_cell, xs, initial_state=S)
|
||||
|
||||
h = seq_to_batch(h5)
|
||||
|
||||
initial_state = np.zeros(S.shape.as_list(), dtype=float)
|
||||
|
||||
return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
|
||||
|
||||
return network_fn
|
||||
|
||||
def tflstm(nlstm=128):
|
||||
def network_fn(X, nenv=1):
|
||||
nbatch = X.shape[0]
|
||||
nsteps = nbatch // nenv
|
||||
|
||||
h = tf.layers.flatten(X)
|
||||
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(nlstm, state_is_tuple=False, forget_bias=0.0)
|
||||
|
||||
S = tf.placeholder(tf.float32, rnn_cell.zero_state(nenv, dtype=tf.float32).shape) #states
|
||||
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
|
||||
initial_state = np.zeros(S.shape)
|
||||
|
||||
h = tf.reshape(h, (-1, nsteps, h.shape[-1]))
|
||||
h, snew = tf.nn.dynamic_rnn(rnn_cell, h, initial_state=S)
|
||||
|
||||
h = tf.reshape(h, (-1, h.shape[-1]))
|
||||
return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
|
||||
|
||||
return network_fn
|
||||
|
||||
def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
|
||||
def network_fn(X, nenv=1):
|
||||
@@ -211,10 +169,6 @@ def get_network_builder(name):
|
||||
return mlp
|
||||
elif name == 'lstm':
|
||||
return lstm
|
||||
elif name == 'tflstm_static':
|
||||
return tflstm_static
|
||||
elif name == 'tflstm':
|
||||
return tflstm
|
||||
elif name == 'cnn_lstm':
|
||||
return cnn_lstm
|
||||
elif name == 'cnn_lnlstm':
|
||||
|
Reference in New Issue
Block a user