refactor common.models via registering reflection (#565)

Author:    Tom
Date:      2018-09-07 01:16:06 +08:00
Committed: pzhokhov
parent 1e9051e87e
commit cc4215ef4b

File: baselines/common/models.py

@@ -5,6 +5,13 @@ from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
 from baselines.common.mpi_running_mean_std import RunningMeanStd
 import tensorflow.contrib.layers as layers
 
+mapping = {}
+
+def register(name):
+    def _thunk(func):
+        mapping[name] = func
+        return func
+    return _thunk
 
 def nature_cnn(unscaled_images, **conv_kwargs):
     """
@@ -20,6 +27,7 @@ def nature_cnn(unscaled_images, **conv_kwargs):
     return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))
 
+@register("mlp")
 def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
     """
     Stack of fully-connected layers to be used in a policy / q-function approximator
@@ -28,16 +36,16 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
     ----------
     num_layers: int number of fully-connected layers (default: 2)
     num_hidden: int size of fully-connected layers (default: 64)
     activation: activation function (default: tf.tanh)
 
     Returns:
     -------
     function that builds fully connected network with a given input tensor / placeholder
     """
     def network_fn(X):
         h = tf.layers.flatten(X)
         for i in range(num_layers):
@@ -45,17 +53,20 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
         return h, None
 
     return network_fn
 
+
+@register("cnn")
 def cnn(**conv_kwargs):
     def network_fn(X):
         return nature_cnn(X, **conv_kwargs), None
     return network_fn
 
+@register("cnn_small")
 def cnn_small(**conv_kwargs):
     def network_fn(X):
         h = tf.cast(X, tf.float32) / 255.
 
         activ = tf.nn.relu
         h = activ(conv(h, 'c1', nf=8, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs))
         h = activ(conv(h, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
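Every registered builder follows the same contract: calling it with hyperparameters returns a `network_fn`, and calling that on an input tensor returns `(output, extra)`, where `extra` is `None` for feed-forward networks. A hedged sketch of consuming the `mlp` builder (TF 1.x; the placeholder shape is illustrative):

    import tensorflow as tf
    from baselines.common.models import mlp

    network_fn = mlp(num_layers=3, num_hidden=32)  # builder returns network_fn
    X = tf.placeholder(tf.float32, [None, 8])      # batch of 8-dim observations
    h, extra = network_fn(X)                       # h has shape [None, 32]; extra is None here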
@@ -65,15 +76,15 @@ def cnn_small(**conv_kwargs):
     return network_fn
 
+@register("lstm")
 def lstm(nlstm=128, layer_norm=False):
     """
     Builds LSTM (Long-Short Term Memory) network to be used in a policy.
     Note that the resulting function returns not only the output of the LSTM
     (i.e. hidden state of lstm for each step in the sequence), but also a dictionary
     with auxiliary tensors to be set as policy attributes.
 
     Specifically,
         S is a placeholder to feed current state (LSTM state has to be managed outside policy)
         M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
         initial_state is a numpy array containing initial lstm state (usually zeros)
@@ -81,7 +92,7 @@ def lstm(nlstm=128, layer_norm=False):
     An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example
 
     Parameters:
     ----------
@@ -94,11 +105,11 @@ def lstm(nlstm=128, layer_norm=False):
     function that builds LSTM with a given input tensor / placeholder
     """
 
     def network_fn(X, nenv=1):
         nbatch = X.shape[0]
         nsteps = nbatch // nenv
 
         h = tf.layers.flatten(X)
 
         M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
@@ -111,7 +122,7 @@ def lstm(nlstm=128, layer_norm=False):
             h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
         else:
             h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
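Per the docstring above, the recurrent builders return the LSTM output plus a dictionary of auxiliary tensors for the policy to manage. A hedged sketch of driving it (the exact dictionary keys are an assumption based on the placeholders shown in this hunk, not confirmed by the diff):

    import tensorflow as tf
    from baselines.common.models import lstm

    network_fn = lstm(nlstm=128)
    X = tf.placeholder(tf.float32, [4, 10])  # nbatch = 4 flattened observations
    h, extra = network_fn(X, nenv=2)         # nsteps = nbatch // nenv = 2
    # extra is assumed to carry the S (state) and M (mask) placeholders shown
    # above, the new state tensor, and initial_state (zeros).
    state = extra['initial_state']           # assumed key, per the docstring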
@@ -120,13 +131,14 @@ def lstm(nlstm=128, layer_norm=False):
     return network_fn
 
+@register("cnn_lstm")
 def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     def network_fn(X, nenv=1):
         nbatch = X.shape[0]
         nsteps = nbatch // nenv
 
         h = nature_cnn(X, **conv_kwargs)
 
         M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
         S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
@@ -137,7 +149,7 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
             h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
         else:
             h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
@@ -145,23 +157,26 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     return network_fn
 
+@register("cnn_lnlstm")
 def cnn_lnlstm(nlstm=128, **conv_kwargs):
     return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
 
+
+@register("conv_only")
 def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
     '''
     convolutions-only net
 
     Parameters:
     ----------
     convs: list of triples (filter_number, filter_size, stride) specifying parameters for each layer.
 
     Returns:
     function that takes tensorflow tensor as input and returns the output of the last convolutional layer
     '''
 
     def network_fn(X):
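A hedged usage sketch for `conv_only` with a custom stack (the layer values are illustrative, not from the commit):

    from baselines.common.models import conv_only

    # two conv layers: 16 filters of size 8 stride 4, then 32 filters of size 4 stride 2
    network_fn = conv_only(convs=[(16, 8, 4), (32, 4, 2)])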
@@ -182,23 +197,22 @@ def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
     rms = RunningMeanStd(shape=x.shape[1:])
     norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
     return norm_x, rms
 
 def get_network_builder(name):
-    # TODO: replace with reflection?
-    if name == 'cnn':
-        return cnn
-    elif name == 'cnn_small':
-        return cnn_small
-    elif name == 'conv_only':
-        return conv_only
-    elif name == 'mlp':
-        return mlp
-    elif name == 'lstm':
-        return lstm
-    elif name == 'cnn_lstm':
-        return cnn_lstm
-    elif name == 'cnn_lnlstm':
-        return cnn_lnlstm
+    """
+    If you want to register your own network outside models.py, you just need:
+
+    Usage Example:
+    -------------
+    from baselines.common.models import register
+    @register("your_network_name")
+    def your_network_define(**net_kwargs):
+        ...
+        return network_fn
+
+    """
+    if name in mapping:
+        return mapping[name]
     else:
         raise ValueError('Unknown network type: {}'.format(name))
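With the reflection-based lookup in place, built-in and user-registered networks are fetched through the same mapping. A minimal sketch following the docstring's recipe (the network name and its trivial body are illustrative):

    from baselines.common.models import get_network_builder, register

    @register("my_tiny_net")        # hypothetical name, registered from outside models.py
    def my_tiny_net(**net_kwargs):
        def network_fn(X):
            return X, None          # identity "network", for illustration only
        return network_fn

    network_fn = get_network_builder("my_tiny_net")()
    cnn_builder = get_network_builder("cnn")   # built-ins resolve the same way
    # unregistered names still raise ValueError('Unknown network type: ...')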