refactor common.models via registering reflection (#565)

This commit is contained in:
Tom
2018-09-07 01:16:06 +08:00
committed by pzhokhov
parent 1e9051e87e
commit cc4215ef4b


@@ -5,6 +5,13 @@ from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
from baselines.common.mpi_running_mean_std import RunningMeanStd
import tensorflow.contrib.layers as layers
mapping = {}

def register(name):
    def _thunk(func):
        mapping[name] = func
        return func
    return _thunk
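To illustrate the mechanism, a minimal sketch (the "my_net" name and builder body below are hypothetical, not part of this commit): decorating a builder stores it in `mapping` under the given name, and the decorator hands the function back unchanged.

# Hypothetical example: registering a custom builder fills `mapping`,
# so get_network_builder (below) can later resolve it by name.
@register("my_net")
def my_net(num_hidden=32):
    def network_fn(X):
        h = tf.layers.flatten(X)
        return tf.tanh(fc(h, 'my_net_fc1', nh=num_hidden)), None
    return network_fn

assert mapping["my_net"] is my_net   # the registry now resolves the name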
def nature_cnn(unscaled_images, **conv_kwargs):
"""
@@ -20,6 +27,7 @@ def nature_cnn(unscaled_images, **conv_kwargs):
    return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))
@register("mlp")
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
"""
Stack of fully-connected layers to be used in a policy / q-function approximator
@@ -28,16 +36,16 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
    ----------

    num_layers: int     number of fully-connected layers (default: 2)
    num_hidden: int     size of fully-connected layers (default: 64)
    activation:         activation function (default: tf.tanh)

    Returns:
    -------

    function that builds fully connected network with a given input tensor / placeholder
    """
    def network_fn(X):
        h = tf.layers.flatten(X)
        for i in range(num_layers):
@@ -45,17 +53,20 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
        return h, None

    return network_fn
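A hedged usage sketch of the builder pattern above (the placeholder shape is illustrative, not from the commit):

# Call the builder to get network_fn, then apply it to an input tensor.
ob_ph = tf.placeholder(tf.float32, [None, 8])
latent, _ = mlp(num_layers=2, num_hidden=64)(ob_ph)   # latent: [None, 64]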
@register("cnn")
def cnn(**conv_kwargs):
    def network_fn(X):
        return nature_cnn(X, **conv_kwargs), None
    return network_fn
@register("cnn_small")
def cnn_small(**conv_kwargs):
    def network_fn(X):
        h = tf.cast(X, tf.float32) / 255.

        activ = tf.nn.relu
        h = activ(conv(h, 'c1', nf=8, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs))
        h = activ(conv(h, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
@@ -65,15 +76,15 @@ def cnn_small(**conv_kwargs):
    return network_fn
@register("lstm")
def lstm(nlstm=128, layer_norm=False):
"""
Builds LSTM (Long-Short Term Memory) network to be used in a policy.
Note that the resulting function returns not only the output of the LSTM
Note that the resulting function returns not only the output of the LSTM
(i.e. hidden state of lstm for each step in the sequence), but also a dictionary
with auxiliary tensors to be set as policy attributes.
with auxiliary tensors to be set as policy attributes.
Specifically,
Specifically,
S is a placeholder to feed current state (LSTM state has to be managed outside policy)
M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
initial_state is a numpy array containing initial lstm state (usually zeros)
@@ -81,7 +92,7 @@ def lstm(nlstm=128, layer_norm=False):
    An example of using an lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example

    Parameters:
    ----------
@@ -94,11 +105,11 @@ def lstm(nlstm=128, layer_norm=False):
    function that builds LSTM with a given input tensor / placeholder
    """
    def network_fn(X, nenv=1):
        nbatch = X.shape[0]
        nsteps = nbatch // nenv

        h = tf.layers.flatten(X)

        M = tf.placeholder(tf.float32, [nbatch]) # mask (done t-1)
@@ -111,7 +122,7 @@ def lstm(nlstm=128, layer_norm=False):
            h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
        else:
            h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)

        h = seq_to_batch(h5)
        initial_state = np.zeros(S.shape.as_list(), dtype=float)
@@ -120,13 +131,14 @@ def lstm(nlstm=128, layer_norm=False):
    return network_fn
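As the docstring above describes, the LSTM builder's network_fn also hands back the auxiliary tensors. A rough usage sketch under that assumption (the dictionary keys and shapes are illustrative, not confirmed by this diff):

# Hypothetical sketch, assuming network_fn returns (output, extras) where
# extras carries the S, M and initial_state described in the docstring above.
X = tf.placeholder(tf.float32, [4, 16])      # nbatch = 4 flattened observations
h, extras = lstm(nlstm=128)(X, nenv=1)
state = extras['initial_state']              # zeros; carried across steps
# At rollout time, feed extras['S'] with `state` and extras['M'] with the
# done flags from the previous step.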
@register("cnn_lstm")
def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
    def network_fn(X, nenv=1):
        nbatch = X.shape[0]
        nsteps = nbatch // nenv

        h = nature_cnn(X, **conv_kwargs)

        M = tf.placeholder(tf.float32, [nbatch]) # mask (done t-1)
        S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) # states
@@ -137,7 +149,7 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
            h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
        else:
            h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)

        h = seq_to_batch(h5)
        initial_state = np.zeros(S.shape.as_list(), dtype=float)
@@ -145,23 +157,26 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
    return network_fn
@register("cnn_lnlstm")
def cnn_lnlstm(nlstm=128, **conv_kwargs):
    return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
    '''
    convolutions-only net

    Parameters:
    ----------

    convs: list of triples (filter_number, filter_size, stride) specifying parameters for each layer.

    Returns:
    -------

    function that takes a tensorflow tensor as input and returns the output of the last convolutional layer
    '''
    def network_fn(X):
@@ -182,23 +197,22 @@ def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
    rms = RunningMeanStd(shape=x.shape[1:])
    norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
    return norm_x, rms
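A short hedged sketch of calling this private helper (the placeholder shape is illustrative):

# Hypothetical usage: clip standardized observations to [-5, 5] and keep
# the RunningMeanStd object so its statistics can be updated later.
x_ph = tf.placeholder(tf.float32, [None, 10])
norm_x, rms = _normalize_clip_observation(x_ph, clip_range=[-5.0, 5.0])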
def get_network_builder(name):
    # TODO: replace with reflection?
    if name == 'cnn':
        return cnn
    elif name == 'cnn_small':
        return cnn_small
    elif name == 'conv_only':
        return conv_only
    elif name == 'mlp':
        return mlp
    elif name == 'lstm':
        return lstm
    elif name == 'cnn_lstm':
        return cnn_lstm
    elif name == 'cnn_lnlstm':
        return cnn_lnlstm
"""
If you want to register your own network outside models.py, you just need:
Usage Example:
-------------
from baselines.common.models import register
@register("your_network_name")
def your_network_define(**net_kwargs):
...
return network_fn
"""
if name in mapping:
return mapping[name]
else:
raise ValueError('Unknown network type: {}'.format(name))
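A short sketch of the new lookup path (the 'mlp' name is registered in this file; the failing name is hypothetical):

builder = get_network_builder('mlp')     # returns the registered mlp builder
network_fn = builder(num_layers=3)       # builder kwargs produce network_fn
get_network_builder('not_registered')    # raises ValueError: Unknown network type: not_registered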