From cc4215ef4b2fbe06d1f40ac4ddf21cb13a9feb5c Mon Sep 17 00:00:00 2001 From: Tom Date: Fri, 7 Sep 2018 01:16:06 +0800 Subject: [PATCH] refactor common.models via registering reflection (#565) --- baselines/common/models.py | 90 ++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 38 deletions(-) diff --git a/baselines/common/models.py b/baselines/common/models.py index 6e1e177..ca07f84 100644 --- a/baselines/common/models.py +++ b/baselines/common/models.py @@ -5,6 +5,13 @@ from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch from baselines.common.mpi_running_mean_std import RunningMeanStd import tensorflow.contrib.layers as layers +mapping = {} + +def register(name): + def _thunk(func): + mapping[name] = func + return func + return _thunk def nature_cnn(unscaled_images, **conv_kwargs): """ @@ -20,6 +27,7 @@ def nature_cnn(unscaled_images, **conv_kwargs): return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2))) +@register("mlp") def mlp(num_layers=2, num_hidden=64, activation=tf.tanh): """ Stack of fully-connected layers to be used in a policy / q-function approximator @@ -28,16 +36,16 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh): ---------- num_layers: int number of fully-connected layers (default: 2) - + num_hidden: int size of fully-connected layers (default: 64) - + activation: activation function (default: tf.tanh) - + Returns: ------- function that builds fully connected network with a given input tensor / placeholder - """ + """ def network_fn(X): h = tf.layers.flatten(X) for i in range(num_layers): @@ -45,17 +53,20 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh): return h, None return network_fn - + +@register("cnn") def cnn(**conv_kwargs): def network_fn(X): return nature_cnn(X, **conv_kwargs), None return network_fn + +@register("cnn_small") def cnn_small(**conv_kwargs): def network_fn(X): h = tf.cast(X, tf.float32) / 255. - + activ = tf.nn.relu h = activ(conv(h, 'c1', nf=8, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs)) h = activ(conv(h, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs)) @@ -65,15 +76,15 @@ def cnn_small(**conv_kwargs): return network_fn - +@register("lstm") def lstm(nlstm=128, layer_norm=False): """ Builds LSTM (Long-Short Term Memory) network to be used in a policy. - Note that the resulting function returns not only the output of the LSTM + Note that the resulting function returns not only the output of the LSTM (i.e. hidden state of lstm for each step in the sequence), but also a dictionary - with auxiliary tensors to be set as policy attributes. + with auxiliary tensors to be set as policy attributes. - Specifically, + Specifically, S is a placeholder to feed current state (LSTM state has to be managed outside policy) M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too) initial_state is a numpy array containing initial lstm state (usually zeros) @@ -81,7 +92,7 @@ def lstm(nlstm=128, layer_norm=False): An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example - + Parameters: ---------- @@ -94,11 +105,11 @@ def lstm(nlstm=128, layer_norm=False): function that builds LSTM with a given input tensor / placeholder """ - + def network_fn(X, nenv=1): - nbatch = X.shape[0] + nbatch = X.shape[0] nsteps = nbatch // nenv - + h = tf.layers.flatten(X) M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) @@ -111,7 +122,7 @@ def lstm(nlstm=128, layer_norm=False): h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm) else: h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm) - + h = seq_to_batch(h5) initial_state = np.zeros(S.shape.as_list(), dtype=float) @@ -120,13 +131,14 @@ def lstm(nlstm=128, layer_norm=False): return network_fn +@register("cnn_lstm") def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs): def network_fn(X, nenv=1): - nbatch = X.shape[0] + nbatch = X.shape[0] nsteps = nbatch // nenv - + h = nature_cnn(X, **conv_kwargs) - + M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states @@ -137,7 +149,7 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs): h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm) else: h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm) - + h = seq_to_batch(h5) initial_state = np.zeros(S.shape.as_list(), dtype=float) @@ -145,23 +157,26 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs): return network_fn + +@register("cnn_lnlstm") def cnn_lnlstm(nlstm=128, **conv_kwargs): return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs) +@register("conv_only") def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs): - ''' + ''' convolutions-only net Parameters: ---------- - conv: list of triples (filter_number, filter_size, stride) specifying parameters for each layer. + conv: list of triples (filter_number, filter_size, stride) specifying parameters for each layer. Returns: function that takes tensorflow tensor as input and returns the output of the last convolutional layer - + ''' def network_fn(X): @@ -182,23 +197,22 @@ def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]): rms = RunningMeanStd(shape=x.shape[1:]) norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range)) return norm_x, rms - + def get_network_builder(name): - # TODO: replace with reflection? - if name == 'cnn': - return cnn - elif name == 'cnn_small': - return cnn_small - elif name == 'conv_only': - return conv_only - elif name == 'mlp': - return mlp - elif name == 'lstm': - return lstm - elif name == 'cnn_lstm': - return cnn_lstm - elif name == 'cnn_lnlstm': - return cnn_lnlstm + """ + If you want to register your own network outside models.py, you just need: + + Usage Example: + ------------- + from baselines.common.models import register + @register("your_network_name") + def your_network_define(**net_kwargs): + ... + return network_fn + + """ + if name in mapping: + return mapping[name] else: raise ValueError('Unknown network type: {}'.format(name))