add RNN class.
@@ -1,18 +1,32 @@
 import numpy as np
 import tensorflow as tf
+import tensorflow.contrib.layers as layers
+
 from baselines.a2c import utils
 from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
 from baselines.common.mpi_running_mean_std import RunningMeanStd
-import tensorflow.contrib.layers as layers
 
 mapping = {}
 
-def register(name):
+
+def register(name, is_rnn=False):
     def _thunk(func):
+        if is_rnn:
+            func = RNN(func)
         mapping[name] = func
         return func
+
     return _thunk
+
+
+class RNN(object):
+    def __init__(self, func):
+        self._func = func
+
+    def __call__(self, *args, **kwargs):
+        return self._func(*args, **kwargs)
+
 
 def nature_cnn(unscaled_images, **conv_kwargs):
     """
     CNN from Nature paper.
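A minimal usage sketch of the new registration path (the toy_rnn name and builder body are illustrative, not part of this commit): registering with is_rnn=True wraps the builder in RNN, so downstream code can detect recurrence with an isinstance check instead of inspecting signatures.

@register("toy_rnn", is_rnn=True)
def toy_rnn(nlstm=128):
    def network_fn(X, nenv=1):
        pass  # graph construction would go here
    return network_fn

builder = mapping["toy_rnn"]       # registry lookup is unchanged
assert isinstance(builder, RNN)    # _thunk wrapped the builder
network_fn = builder(nlstm=256)    # __call__ forwards to the wrapped func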
@@ -46,6 +60,7 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
 
     function that builds fully connected network with a given input tensor / placeholder
     """
+
     def network_fn(X):
         h = tf.layers.flatten(X)
         for i in range(num_layers):
@@ -63,6 +78,7 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
 def cnn(**conv_kwargs):
     def network_fn(X):
         return nature_cnn(X, **conv_kwargs)
+
     return network_fn
 
 
@@ -77,10 +93,11 @@ def cnn_small(**conv_kwargs):
         h = conv_to_fc(h)
         h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2)))
         return h
+
     return network_fn
 
 
-@register("lstm")
+@register("lstm", is_rnn=True)
 def lstm(nlstm=128, layer_norm=False):
     """
     Builds LSTM (Long-Short Term Memory) network to be used in a policy.
@@ -116,8 +133,8 @@ def lstm(nlstm=128, layer_norm=False):
 
         h = tf.layers.flatten(X)
 
-        M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
-        S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
+        M = tf.placeholder(tf.float32, [nbatch])  # mask (done t-1)
+        S = tf.placeholder(tf.float32, [nenv, 2 * nlstm])  # states
 
         xs = batch_to_seq(h, nenv, nsteps)
         ms = batch_to_seq(M, nenv, nsteps)
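For readers new to this code, the shape bookkeeping behind the placeholders above (a sketch; the numbers are illustrative): the rollout batch flattens nenv environments over nsteps time steps, and S packs the LSTM cell and hidden state side by side, which is why its width is 2 * nlstm.

nenv, nsteps, nlstm = 4, 8, 128
nbatch = nenv * nsteps   # 32 rows in X, h and M
# M holds the done flags from t-1, used to reset carried state at episode
# boundaries; S has one row per environment: [c | h], each nlstm wide.
# batch_to_seq(h, nenv, nsteps) reshapes the flat (nbatch, d) tensor into a
# list of nsteps tensors of shape (nenv, d) so the recurrence can run step
# by step; seq_to_batch inverts that after the loop.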
@@ -130,12 +147,12 @@ def lstm(nlstm=128, layer_norm=False):
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
 
-        return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
+        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
 
     return network_fn
 
 
-@register("cnn_lstm")
+@register("cnn_lstm", is_rnn=True)
 def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     def network_fn(X, nenv=1):
         nbatch = X.shape[0]
@@ -143,8 +160,8 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
 
         h = nature_cnn(X, **conv_kwargs)
 
-        M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
-        S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
+        M = tf.placeholder(tf.float32, [nbatch])  # mask (done t-1)
+        S = tf.placeholder(tf.float32, [nenv, 2 * nlstm])  # states
 
         xs = batch_to_seq(h, nenv, nsteps)
         ms = batch_to_seq(M, nenv, nsteps)
@@ -157,12 +174,12 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
 
-        return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
+        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
 
     return network_fn
 
 
-@register("cnn_lnlstm")
+@register("cnn_lnlstm", is_rnn=True)
 def cnn_lnlstm(nlstm=128, **conv_kwargs):
     return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
@@ -195,8 +212,10 @@ def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
                            **conv_kwargs)
 
         return out
+
     return network_fn
 
+
 def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
     rms = RunningMeanStd(shape=x.shape[1:])
     norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
@@ -2,10 +2,10 @@ import numpy as np
 import tensorflow as tf
 
 from baselines.a2c.utils import ortho_init, fc
-from baselines.common.models import register
+from baselines.common.models import register, nature_cnn, RNN
 
 
-@register("ppo_lstm")
+@register("ppo_lstm", is_rnn=True)
 def ppo_lstm(nlstm=128, layer_norm=False):
     def network_fn(input, mask):
         memory_size = nlstm * 2
@@ -27,13 +27,13 @@ def ppo_lstm(nlstm=128, layer_norm=False):
             h, next_state = lstm(input, mask, state, scope='lstm', nh=nlstm)
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
-@register("ppo_cnn_lstm")
-def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
+@register("ppo_cnn_lstm", is_rnn=True)
+def ppo_cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     def network_fn(input, mask):
         memory_size = nlstm * 2
         nbatch = input.shape[0]
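A sketch of the calling convention this change produces (the placeholder names are illustrative): both the outer builder result and the inner step function are now RNN instances, so the policy builder can recognize them by type.

input_ph = tf.placeholder(tf.float32, [None, 64])   # hypothetical features
mask_ph = tf.placeholder(tf.float32, [None])        # done flags from t-1

builder = mapping["ppo_lstm"]              # an RNN, since is_rnn=True
network_fn = builder(nlstm=128)            # RNN(network_fn) per this commit
state_ph, step_fn = network_fn(input_ph, mask_ph)   # step_fn is RNN(_network_fn)
h, next_state = step_fn(input_ph, mask_ph, state_ph)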
@@ -48,27 +48,7 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
         mask = tf.to_float(mask)
         initializer = ortho_init(np.sqrt(2))
 
-        h = tf.contrib.layers.conv2d(input,
-                                     num_outputs=32,
-                                     kernel_size=8,
-                                     stride=4,
-                                     padding=pad,
-                                     weights_initializer=initializer,
-                                     **conv_kwargs)
-        h = tf.contrib.layers.conv2d(h,
-                                     num_outputs=64,
-                                     kernel_size=4,
-                                     stride=2,
-                                     padding=pad,
-                                     weights_initializer=initializer,
-                                     **conv_kwargs)
-        h = tf.contrib.layers.conv2d(h,
-                                     num_outputs=64,
-                                     kernel_size=3,
-                                     stride=1,
-                                     padding=pad,
-                                     weights_initializer=initializer,
-                                     **conv_kwargs)
+        h = nature_cnn(input, **conv_kwargs)
         h = tf.layers.flatten(h)
         h = tf.layers.dense(h, units=512, activation=tf.nn.relu, kernel_initializer=initializer)
 
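The deleted stack duplicated the Nature CNN that baselines.common.models already provides. For reference, upstream nature_cnn looked roughly like this at the time (reproduced from OpenAI baselines; this fork's copy may differ in detail):

def nature_cnn(unscaled_images, **conv_kwargs):
    """CNN from the DQN Nature paper."""
    scaled_images = tf.cast(unscaled_images, tf.float32) / 255.
    activ = tf.nn.relu
    h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4,
                   init_scale=np.sqrt(2), **conv_kwargs))
    h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2,
                    init_scale=np.sqrt(2), **conv_kwargs))
    h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1,
                    init_scale=np.sqrt(2), **conv_kwargs))
    h3 = conv_to_fc(h3)
    return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))

Two functional differences worth noting: nature_cnn rescales its input by 1/255, which the removed stack did not do explicitly, and it already ends in a 512-unit fully connected layer, so the flatten and dense lines kept in the hunk above now add a second dense layer on top.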
@@ -78,17 +58,17 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
             h, next_state = lstm(h, mask, state, scope='lstm', nh=nlstm)
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
-@register("ppo_cnn_lnlstm")
+@register("ppo_cnn_lnlstm", is_rnn=True)
 def ppo_cnn_lnlstm(nlstm=128, **conv_kwargs):
     return ppo_cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
 
 
-@register("ppo_gru")
+@register("ppo_gru", is_rnn=True)
 def ppo_gru(nlstm=128):
     def network_fn(input, mask):
         memory_size = nlstm
@@ -107,12 +87,12 @@ def ppo_gru(nlstm=128):
             h, next_state = gru(input, mask, state, nh=nlstm)
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
-@register("ppo_lstm_mlp")
+@register("ppo_lstm_mlp", is_rnn=True)
 def ppo_lstm_mlp(nlstm=128, layer_norm=False):
     def network_fn(input, mask):
         memory_size = nlstm * 2
@@ -138,12 +118,12 @@ def ppo_lstm_mlp(nlstm=128, layer_norm=False):
             h = activation(h)
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
-@register("ppo_gru_mlp")
+@register("ppo_gru_mlp", is_rnn=True)
 def ppo_gru_mlp(nlstm=128):
     def network_fn(input, mask):
         memory_size = nlstm
@@ -170,9 +150,9 @@ def ppo_gru_mlp(nlstm=128):
 
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
 def lstm(x, m, s, scope, nh, init_scale=1.0):
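The lstm helper above takes a single batch step x, mask m, and packed state s, rather than per-step lists. A sketch of the masked-reset pattern such helpers implement in baselines-style code (an assumption based on baselines' a2c.utils.lstm, not a copy of this fork's body):

def masked_state_reset(m, s):
    # m: (nbatch,) done flags from t-1; s: (nbatch, 2*nh) packed as [c | h].
    c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
    c = c * (1.0 - m[:, None])   # zero the cell state where the episode ended
    h = h * (1.0 - m[:, None])   # zero the hidden state likewise
    return c, h                  # gate computations then produce new (c, h)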
@@ -1,12 +1,12 @@
-import inspect
-
 import gym
 import numpy as np
 import tensorflow as tf
 
 from baselines.a2c.utils import fc
 from baselines.common import tf_util
 from baselines.common.distributions import make_pdtype
 from baselines.common.input import observation_placeholder, encode_observation
+from baselines.common.models import RNN
 from baselines.common.models import get_network_builder
 from baselines.common.tf_util import adjust_shape
 
@@ -125,7 +125,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
         encoded_x = encode_observation(ob_space, X)
 
         with tf.variable_scope('load_rnn_memory'):
-            if is_rnn_network(policy_network):
+            if isinstance(policy_network, RNN):
                 policy_state, policy_network_ = policy_network(encoded_x, dones)
             else:
                 policy_network_ = policy_network
@@ -139,7 +139,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
             assert callable(value_network)
             value_network_ = value_network
 
-        if is_rnn_network(value_network_):
+        if isinstance(value_network_, RNN):
             value_state, value_network_ = value_network_(encoded_x, dones)
 
         if policy_state or value_state:
@@ -154,7 +154,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
             index += size
 
         with tf.variable_scope('policy_latent', reuse=tf.AUTO_REUSE):
-            if is_rnn_network(policy_network_):
+            if isinstance(policy_network_, RNN):
                 policy_latent, next_policy_state = \
                     policy_network_(encoded_x, dones, state_map[policy_state])
                 next_states_list.append(next_policy_state)
@@ -164,7 +164,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
         with tf.variable_scope('value_latent', reuse=tf.AUTO_REUSE):
             if value_network_ == 'shared':
                 value_latent = policy_latent
-            elif is_rnn_network(value_network_):
+            elif isinstance(value_network_, RNN):
                 value_latent, next_value_state = \
                     value_network_(encoded_x, dones, state_map[value_state])
                 next_states_list.append(next_value_state)
@@ -201,7 +201,3 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
         return policy
 
     return policy_fn
-
-
-def is_rnn_network(network):
-    return 'mask' in inspect.getfullargspec(network).args
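Replacing is_rnn_network with isinstance(..., RNN) also removes a structural fragility: the old check looked for a literal 'mask' argument in the signature, which the RNN wrapper's *args/**kwargs forwarding hides. A small illustration (step_fn is hypothetical):

import inspect

def step_fn(input, mask, state):
    return None                    # placeholder body

wrapped = RNN(step_fn)             # __call__ forwards *args/**kwargs

inspect.getfullargspec(wrapped.__call__).args   # ['self'] -- no 'mask' here
isinstance(wrapped, RNN)                        # True: explicit and cheap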