refactor common.models via registering reflection (#565)

Author: Tom
Date: 2018-09-07 01:16:06 +08:00
Committed by: pzhokhov
Parent: 1e9051e87e
Commit: cc4215ef4b

baselines/common/models.py

@@ -5,6 +5,13 @@ from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
 from baselines.common.mpi_running_mean_std import RunningMeanStd
 import tensorflow.contrib.layers as layers

+mapping = {}
+
+def register(name):
+    def _thunk(func):
+        mapping[name] = func
+        return func
+    return _thunk

 def nature_cnn(unscaled_images, **conv_kwargs):
     """
@@ -20,6 +27,7 @@ def nature_cnn(unscaled_images, **conv_kwargs):
     return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))


+@register("mlp")
 def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
     """
     Stack of fully-connected layers to be used in a policy / q-function approximator
@@ -47,11 +55,14 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
     return network_fn


+@register("cnn")
 def cnn(**conv_kwargs):
     def network_fn(X):
         return nature_cnn(X, **conv_kwargs), None
     return network_fn

+
+@register("cnn_small")
 def cnn_small(**conv_kwargs):
     def network_fn(X):
         h = tf.cast(X, tf.float32) / 255.
@@ -65,7 +76,7 @@ def cnn_small(**conv_kwargs):
     return network_fn


-
+@register("lstm")
 def lstm(nlstm=128, layer_norm=False):
     """
     Builds LSTM (Long-Short Term Memory) network to be used in a policy.
@@ -120,6 +131,7 @@ def lstm(nlstm=128, layer_norm=False):
     return network_fn


+@register("cnn_lstm")
 def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     def network_fn(X, nenv=1):
         nbatch = X.shape[0]
@@ -145,10 +157,13 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     return network_fn


+@register("cnn_lnlstm")
 def cnn_lnlstm(nlstm=128, **conv_kwargs):
     return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)


+
+@register("conv_only")
 def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
     '''
     convolutions-only net
@@ -185,20 +200,19 @@ def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
 def get_network_builder(name):
-    # TODO: replace with reflection?
-    if name == 'cnn':
-        return cnn
-    elif name == 'cnn_small':
-        return cnn_small
-    elif name == 'conv_only':
-        return conv_only
-    elif name == 'mlp':
-        return mlp
-    elif name == 'lstm':
-        return lstm
-    elif name == 'cnn_lstm':
-        return cnn_lstm
-    elif name == 'cnn_lnlstm':
-        return cnn_lnlstm
+    """
+    If you want to register your own network outside models.py, you just need:
+
+    Usage Example:
+    -------------
+    from baselines.common.models import register
+    @register("your_network_name")
+    def your_network_define(**net_kwargs):
+        ...
+        return network_fn
+
+    """
+    if name in mapping:
+        return mapping[name]
     else:
         raise ValueError('Unknown network type: {}'.format(name))
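
After this change, adding a network outside models.py only requires decorating its builder, and get_network_builder reduces to a dictionary lookup. A rough usage sketch along the lines of the docstring above; the name "my_mlp" and the builder body are illustrative and not part of baselines:

    import tensorflow as tf
    from baselines.common.models import register, get_network_builder

    @register("my_mlp")                      # hypothetical network name
    def my_mlp(num_hidden=32):
        def network_fn(X):
            h = tf.layers.flatten(X)
            h = tf.layers.dense(h, num_hidden, activation=tf.tanh)
            return h, None                   # (output, None), matching cnn() above
        return network_fn

    builder = get_network_builder("my_mlp")  # resolved via the mapping dict
    network_fn = builder(num_hidden=64)
    # Unknown names still fail loudly:
    # get_network_builder("no_such_net") -> ValueError: Unknown network type: no_such_net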