refactor common.models via registering reflection (#565)
This commit is contained in:
@@ -5,6 +5,13 @@ from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
|
|||||||
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
||||||
import tensorflow.contrib.layers as layers
|
import tensorflow.contrib.layers as layers
|
||||||
|
|
||||||
|
mapping = {}
|
||||||
|
|
||||||
|
def register(name):
|
||||||
|
def _thunk(func):
|
||||||
|
mapping[name] = func
|
||||||
|
return func
|
||||||
|
return _thunk
|
||||||
|
|
||||||
def nature_cnn(unscaled_images, **conv_kwargs):
|
def nature_cnn(unscaled_images, **conv_kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -20,6 +27,7 @@ def nature_cnn(unscaled_images, **conv_kwargs):
|
|||||||
return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))
|
return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))
|
||||||
|
|
||||||
|
|
||||||
|
@register("mlp")
|
||||||
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
|
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
|
||||||
"""
|
"""
|
||||||
Stack of fully-connected layers to be used in a policy / q-function approximator
|
Stack of fully-connected layers to be used in a policy / q-function approximator
|
||||||
@@ -28,16 +36,16 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
|
|||||||
----------
|
----------
|
||||||
|
|
||||||
num_layers: int number of fully-connected layers (default: 2)
|
num_layers: int number of fully-connected layers (default: 2)
|
||||||
|
|
||||||
num_hidden: int size of fully-connected layers (default: 64)
|
num_hidden: int size of fully-connected layers (default: 64)
|
||||||
|
|
||||||
activation: activation function (default: tf.tanh)
|
activation: activation function (default: tf.tanh)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
-------
|
-------
|
||||||
|
|
||||||
function that builds fully connected network with a given input tensor / placeholder
|
function that builds fully connected network with a given input tensor / placeholder
|
||||||
"""
|
"""
|
||||||
def network_fn(X):
|
def network_fn(X):
|
||||||
h = tf.layers.flatten(X)
|
h = tf.layers.flatten(X)
|
||||||
for i in range(num_layers):
|
for i in range(num_layers):
|
||||||
@@ -45,17 +53,20 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
|
|||||||
return h, None
|
return h, None
|
||||||
|
|
||||||
return network_fn
|
return network_fn
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@register("cnn")
|
||||||
def cnn(**conv_kwargs):
|
def cnn(**conv_kwargs):
|
||||||
def network_fn(X):
|
def network_fn(X):
|
||||||
return nature_cnn(X, **conv_kwargs), None
|
return nature_cnn(X, **conv_kwargs), None
|
||||||
return network_fn
|
return network_fn
|
||||||
|
|
||||||
|
|
||||||
|
@register("cnn_small")
|
||||||
def cnn_small(**conv_kwargs):
|
def cnn_small(**conv_kwargs):
|
||||||
def network_fn(X):
|
def network_fn(X):
|
||||||
h = tf.cast(X, tf.float32) / 255.
|
h = tf.cast(X, tf.float32) / 255.
|
||||||
|
|
||||||
activ = tf.nn.relu
|
activ = tf.nn.relu
|
||||||
h = activ(conv(h, 'c1', nf=8, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs))
|
h = activ(conv(h, 'c1', nf=8, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs))
|
||||||
h = activ(conv(h, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
|
h = activ(conv(h, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
|
||||||
@@ -65,15 +76,15 @@ def cnn_small(**conv_kwargs):
|
|||||||
return network_fn
|
return network_fn
|
||||||
|
|
||||||
|
|
||||||
|
@register("lstm")
|
||||||
def lstm(nlstm=128, layer_norm=False):
|
def lstm(nlstm=128, layer_norm=False):
|
||||||
"""
|
"""
|
||||||
Builds LSTM (Long-Short Term Memory) network to be used in a policy.
|
Builds LSTM (Long-Short Term Memory) network to be used in a policy.
|
||||||
Note that the resulting function returns not only the output of the LSTM
|
Note that the resulting function returns not only the output of the LSTM
|
||||||
(i.e. hidden state of lstm for each step in the sequence), but also a dictionary
|
(i.e. hidden state of lstm for each step in the sequence), but also a dictionary
|
||||||
with auxiliary tensors to be set as policy attributes.
|
with auxiliary tensors to be set as policy attributes.
|
||||||
|
|
||||||
Specifically,
|
Specifically,
|
||||||
S is a placeholder to feed current state (LSTM state has to be managed outside policy)
|
S is a placeholder to feed current state (LSTM state has to be managed outside policy)
|
||||||
M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
|
M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
|
||||||
initial_state is a numpy array containing initial lstm state (usually zeros)
|
initial_state is a numpy array containing initial lstm state (usually zeros)
|
||||||
@@ -81,7 +92,7 @@ def lstm(nlstm=128, layer_norm=False):
|
|||||||
|
|
||||||
|
|
||||||
An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example
|
An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
----------
|
----------
|
||||||
|
|
||||||
@@ -94,11 +105,11 @@ def lstm(nlstm=128, layer_norm=False):
|
|||||||
|
|
||||||
function that builds LSTM with a given input tensor / placeholder
|
function that builds LSTM with a given input tensor / placeholder
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def network_fn(X, nenv=1):
|
def network_fn(X, nenv=1):
|
||||||
nbatch = X.shape[0]
|
nbatch = X.shape[0]
|
||||||
nsteps = nbatch // nenv
|
nsteps = nbatch // nenv
|
||||||
|
|
||||||
h = tf.layers.flatten(X)
|
h = tf.layers.flatten(X)
|
||||||
|
|
||||||
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
|
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
|
||||||
@@ -111,7 +122,7 @@ def lstm(nlstm=128, layer_norm=False):
|
|||||||
h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
|
h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
|
||||||
else:
|
else:
|
||||||
h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
|
h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
|
||||||
|
|
||||||
h = seq_to_batch(h5)
|
h = seq_to_batch(h5)
|
||||||
initial_state = np.zeros(S.shape.as_list(), dtype=float)
|
initial_state = np.zeros(S.shape.as_list(), dtype=float)
|
||||||
|
|
||||||
@@ -120,13 +131,14 @@ def lstm(nlstm=128, layer_norm=False):
|
|||||||
return network_fn
|
return network_fn
|
||||||
|
|
||||||
|
|
||||||
|
@register("cnn_lstm")
|
||||||
def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
|
def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
|
||||||
def network_fn(X, nenv=1):
|
def network_fn(X, nenv=1):
|
||||||
nbatch = X.shape[0]
|
nbatch = X.shape[0]
|
||||||
nsteps = nbatch // nenv
|
nsteps = nbatch // nenv
|
||||||
|
|
||||||
h = nature_cnn(X, **conv_kwargs)
|
h = nature_cnn(X, **conv_kwargs)
|
||||||
|
|
||||||
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
|
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
|
||||||
S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
|
S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
|
||||||
|
|
||||||
@@ -137,7 +149,7 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
|
|||||||
h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
|
h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
|
||||||
else:
|
else:
|
||||||
h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
|
h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
|
||||||
|
|
||||||
h = seq_to_batch(h5)
|
h = seq_to_batch(h5)
|
||||||
initial_state = np.zeros(S.shape.as_list(), dtype=float)
|
initial_state = np.zeros(S.shape.as_list(), dtype=float)
|
||||||
|
|
||||||
@@ -145,23 +157,26 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
|
|||||||
|
|
||||||
return network_fn
|
return network_fn
|
||||||
|
|
||||||
|
|
||||||
|
@register("cnn_lnlstm")
|
||||||
def cnn_lnlstm(nlstm=128, **conv_kwargs):
|
def cnn_lnlstm(nlstm=128, **conv_kwargs):
|
||||||
return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
|
return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register("conv_only")
|
||||||
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
|
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
|
||||||
'''
|
'''
|
||||||
convolutions-only net
|
convolutions-only net
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
----------
|
----------
|
||||||
|
|
||||||
conv: list of triples (filter_number, filter_size, stride) specifying parameters for each layer.
|
conv: list of triples (filter_number, filter_size, stride) specifying parameters for each layer.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
function that takes tensorflow tensor as input and returns the output of the last convolutional layer
|
function that takes tensorflow tensor as input and returns the output of the last convolutional layer
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def network_fn(X):
|
def network_fn(X):
|
||||||
@@ -182,23 +197,22 @@ def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
|
|||||||
rms = RunningMeanStd(shape=x.shape[1:])
|
rms = RunningMeanStd(shape=x.shape[1:])
|
||||||
norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
|
norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
|
||||||
return norm_x, rms
|
return norm_x, rms
|
||||||
|
|
||||||
|
|
||||||
def get_network_builder(name):
|
def get_network_builder(name):
|
||||||
# TODO: replace with reflection?
|
"""
|
||||||
if name == 'cnn':
|
If you want to register your own network outside models.py, you just need:
|
||||||
return cnn
|
|
||||||
elif name == 'cnn_small':
|
Usage Example:
|
||||||
return cnn_small
|
-------------
|
||||||
elif name == 'conv_only':
|
from baselines.common.models import register
|
||||||
return conv_only
|
@register("your_network_name")
|
||||||
elif name == 'mlp':
|
def your_network_define(**net_kwargs):
|
||||||
return mlp
|
...
|
||||||
elif name == 'lstm':
|
return network_fn
|
||||||
return lstm
|
|
||||||
elif name == 'cnn_lstm':
|
"""
|
||||||
return cnn_lstm
|
if name in mapping:
|
||||||
elif name == 'cnn_lnlstm':
|
return mapping[name]
|
||||||
return cnn_lnlstm
|
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unknown network type: {}'.format(name))
|
raise ValueError('Unknown network type: {}'.format(name))
|
||||||
|
Reference in New Issue
Block a user