diff --git a/baselines/common/models.py b/baselines/common/models.py
index 0003079..3630337 100644
--- a/baselines/common/models.py
+++ b/baselines/common/models.py
@@ -1,18 +1,32 @@
 import numpy as np
 import tensorflow as tf
+import tensorflow.contrib.layers as layers
+
 from baselines.a2c import utils
 from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
 from baselines.common.mpi_running_mean_std import RunningMeanStd
-import tensorflow.contrib.layers as layers
 
 mapping = {}
 
-def register(name):
+
+def register(name, is_rnn=False):
     def _thunk(func):
+        if is_rnn:
+            func = RNN(func)
         mapping[name] = func
         return func
+
     return _thunk
 
+
+class RNN(object):
+    def __init__(self, func):
+        self._func = func
+
+    def __call__(self, *args, **kwargs):
+        return self._func(*args, **kwargs)
+
+
 def nature_cnn(unscaled_images, **conv_kwargs):
     """
     CNN from Nature paper.
@@ -46,6 +60,7 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
 
         function that builds fully connected network with a given input tensor / placeholder
     """
+
     def network_fn(X):
         h = tf.layers.flatten(X)
         for i in range(num_layers):
@@ -63,6 +78,7 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
 def cnn(**conv_kwargs):
     def network_fn(X):
         return nature_cnn(X, **conv_kwargs)
+
     return network_fn
 
 
@@ -77,10 +93,11 @@ def cnn_small(**conv_kwargs):
         h = conv_to_fc(h)
         h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2)))
         return h
+
     return network_fn
 
 
-@register("lstm")
+@register("lstm", is_rnn=True)
 def lstm(nlstm=128, layer_norm=False):
     """
     Builds LSTM (Long-Short Term Memory) network to be used in a policy.
@@ -116,8 +133,8 @@ def lstm(nlstm=128, layer_norm=False):
 
         h = tf.layers.flatten(X)
 
-        M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
-        S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
+        M = tf.placeholder(tf.float32, [nbatch])  # mask (done t-1)
+        S = tf.placeholder(tf.float32, [nenv, 2 * nlstm])  # states
 
         xs = batch_to_seq(h, nenv, nsteps)
         ms = batch_to_seq(M, nenv, nsteps)
@@ -130,12 +147,12 @@ def lstm(nlstm=128, layer_norm=False):
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
 
-        return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
+        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
 
     return network_fn
 
 
-@register("cnn_lstm")
+@register("cnn_lstm", is_rnn=True)
 def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     def network_fn(X, nenv=1):
         nbatch = X.shape[0]
@@ -143,8 +160,8 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
 
         h = nature_cnn(X, **conv_kwargs)
 
-        M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
-        S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
+        M = tf.placeholder(tf.float32, [nbatch])  # mask (done t-1)
+        S = tf.placeholder(tf.float32, [nenv, 2 * nlstm])  # states
 
         xs = batch_to_seq(h, nenv, nsteps)
         ms = batch_to_seq(M, nenv, nsteps)
@@ -157,12 +174,12 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
 
-        return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
+        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
 
     return network_fn
 
 
-@register("cnn_lnlstm")
+@register("cnn_lnlstm", is_rnn=True)
 def cnn_lnlstm(nlstm=128, **conv_kwargs):
     return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
 
@@ -195,8 +212,10 @@ def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
                                            **conv_kwargs)
 
         return out
+
     return network_fn
 
+
 def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
     rms = RunningMeanStd(shape=x.shape[1:])
     norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
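Note on models.py: register() now takes an is_rnn flag and wraps recurrent builders in the new RNN marker class before storing them in mapping, so recurrence is encoded in the type of the registered object rather than in its argument names. A minimal sketch of the resulting behavior (not part of the diff; it assumes the models.py definitions above are importable):

    from baselines.common.models import mapping, RNN

    builder = mapping['lstm']        # registered via @register("lstm", is_rnn=True)
    assert isinstance(builder, RNN)  # the wrapper marks the builder as recurrent
    network_fn = builder(nlstm=256)  # RNN.__call__ forwards to the wrapped builder

    # assumes mlp is registered without is_rnn, as elsewhere in baselines
    assert not isinstance(mapping['mlp'], RNN)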
diff --git a/baselines/ppo2/layers.py b/baselines/ppo2/layers.py
index eac63ab..69d0cb5 100644
--- a/baselines/ppo2/layers.py
+++ b/baselines/ppo2/layers.py
@@ -2,10 +2,10 @@ import numpy as np
 import tensorflow as tf
 
 from baselines.a2c.utils import ortho_init, fc
-from baselines.common.models import register
+from baselines.common.models import register, nature_cnn, RNN
 
 
-@register("ppo_lstm")
+@register("ppo_lstm", is_rnn=True)
 def ppo_lstm(nlstm=128, layer_norm=False):
     def network_fn(input, mask):
         memory_size = nlstm * 2
@@ -27,13 +27,13 @@ def ppo_lstm(nlstm=128, layer_norm=False):
             h, next_state = lstm(input, mask, state, scope='lstm', nh=nlstm)
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
-@register("ppo_cnn_lstm")
-def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
+@register("ppo_cnn_lstm", is_rnn=True)
+def ppo_cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     def network_fn(input, mask):
         memory_size = nlstm * 2
         nbatch = input.shape[0]
@@ -48,27 +48,7 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
         mask = tf.to_float(mask)
         initializer = ortho_init(np.sqrt(2))
 
-        h = tf.contrib.layers.conv2d(input,
-                                     num_outputs=32,
-                                     kernel_size=8,
-                                     stride=4,
-                                     padding=pad,
-                                     weights_initializer=initializer,
-                                     **conv_kwargs)
-        h = tf.contrib.layers.conv2d(h,
-                                     num_outputs=64,
-                                     kernel_size=4,
-                                     stride=2,
-                                     padding=pad,
-                                     weights_initializer=initializer,
-                                     **conv_kwargs)
-        h = tf.contrib.layers.conv2d(h,
-                                     num_outputs=64,
-                                     kernel_size=3,
-                                     stride=1,
-                                     padding=pad,
-                                     weights_initializer=initializer,
-                                     **conv_kwargs)
+        h = nature_cnn(input, **conv_kwargs)
         h = tf.layers.flatten(h)
         h = tf.layers.dense(h, units=512, activation=tf.nn.relu, kernel_initializer=initializer)
@@ -78,17 +58,17 @@
             h, next_state = lstm(h, mask, state, scope='lstm', nh=nlstm)
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
-@register("ppo_cnn_lnlstm")
+@register("ppo_cnn_lnlstm", is_rnn=True)
 def ppo_cnn_lnlstm(nlstm=128, **conv_kwargs):
     return ppo_cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
 
 
-@register("ppo_gru")
+@register("ppo_gru", is_rnn=True)
 def ppo_gru(nlstm=128):
     def network_fn(input, mask):
         memory_size = nlstm
@@ -107,12 +87,12 @@ def ppo_gru(nlstm=128):
             h, next_state = gru(input, mask, state, nh=nlstm)
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
-@register("ppo_lstm_mlp")
+@register("ppo_lstm_mlp", is_rnn=True)
 def ppo_lstm_mlp(nlstm=128, layer_norm=False):
     def network_fn(input, mask):
         memory_size = nlstm * 2
@@ -138,12 +118,12 @@ def ppo_lstm_mlp(nlstm=128, layer_norm=False):
                 h = activation(h)
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
-@register("ppo_gru_mlp")
+@register("ppo_gru_mlp", is_rnn=True)
 def ppo_gru_mlp(nlstm=128):
     def network_fn(input, mask):
         memory_size = nlstm
@@ -170,9 +150,9 @@ def ppo_gru_mlp(nlstm=128):
 
             return h, next_state
 
-        return state, _network_fn
+        return state, RNN(_network_fn)
 
-    return network_fn
+    return RNN(network_fn)
 
 
 def lstm(x, m, s, scope, nh, init_scale=1.0):
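Note on layers.py: every ppo_* builder now returns RNN-wrapped callables at both levels, matching the isinstance checks in policies.py below, and ppo_cnn_lstm reuses nature_cnn in place of three hand-rolled tf.contrib conv2d layers (dropping the now-unused pad argument). The networks keep their two-stage protocol, sketched here under assumed placeholder shapes (the 8x64 input and the local variable names are illustrative, not taken from the diff):

    import tensorflow as tf
    from baselines.ppo2.layers import ppo_lstm

    network = ppo_lstm(nlstm=128)            # an RNN instance wrapping network_fn
    x = tf.placeholder(tf.float32, [8, 64])  # batch of encoded observations
    mask = tf.placeholder(tf.float32, [8])   # done flags from the previous step

    state, step_fn = network(x, mask)        # stage 1: allocate the memory placeholder
    h, next_state = step_fn(x, mask, state)  # stage 2: latent features + updated memory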
diff --git a/baselines/ppo2/policies.py b/baselines/ppo2/policies.py
index fb1476f..5e2b7e2 100644
--- a/baselines/ppo2/policies.py
+++ b/baselines/ppo2/policies.py
@@ -1,12 +1,12 @@
-import inspect
-
 import gym
 import numpy as np
 import tensorflow as tf
+
 from baselines.a2c.utils import fc
 from baselines.common import tf_util
 from baselines.common.distributions import make_pdtype
 from baselines.common.input import observation_placeholder, encode_observation
+from baselines.common.models import RNN
 from baselines.common.models import get_network_builder
 from baselines.common.tf_util import adjust_shape
 
@@ -125,7 +125,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
         encoded_x = encode_observation(ob_space, X)
 
         with tf.variable_scope('load_rnn_memory'):
-            if is_rnn_network(policy_network):
+            if isinstance(policy_network, RNN):
                 policy_state, policy_network_ = policy_network(encoded_x, dones)
             else:
                 policy_network_ = policy_network
@@ -139,7 +139,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
                 assert callable(value_network)
                 value_network_ = value_network
 
-            if is_rnn_network(value_network_):
+            if isinstance(value_network_, RNN):
                 value_state, value_network_ = value_network_(encoded_x, dones)
 
         if policy_state or value_state:
@@ -154,7 +154,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
                 index += size
 
         with tf.variable_scope('policy_latent', reuse=tf.AUTO_REUSE):
-            if is_rnn_network(policy_network_):
+            if isinstance(policy_network_, RNN):
                 policy_latent, next_policy_state = \
                     policy_network_(encoded_x, dones, state_map[policy_state])
                 next_states_list.append(next_policy_state)
@@ -164,7 +164,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
         with tf.variable_scope('value_latent', reuse=tf.AUTO_REUSE):
             if value_network_ == 'shared':
                 value_latent = policy_latent
-            elif is_rnn_network(value_network_):
+            elif isinstance(value_network_, RNN):
                 value_latent, next_value_state = \
                     value_network_(encoded_x, dones, state_map[value_state])
                 next_states_list.append(next_value_state)
@@ -201,7 +201,3 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
         return policy
 
     return policy_fn
-
-
-def is_rnn_network(network):
-    return 'mask' in inspect.getfullargspec(network).args
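Note on policies.py: build_ppo_policy now detects recurrence with isinstance(network, RNN), an explicit marker attached at registration, and the signature-sniffing helper is deleted. For reference, the removed check and its replacement side by side (is_rnn_network_old/_new are hypothetical names used only for this comparison):

    import inspect
    from baselines.common.models import RNN

    def is_rnn_network_old(network):
        # removed helper: inspects argument names, which breaks once a
        # builder is wrapped in a class or accepts **kwargs
        return 'mask' in inspect.getfullargspec(network).args

    def is_rnn_network_new(network):
        # replacement behavior: a type check against the RNN marker class
        return isinstance(network, RNN)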