Add RNN wrapper class to mark recurrent network builders.

This commit is contained in:
gyunt
2019-04-08 18:35:05 +09:00
parent 1dbfbaac16
commit e6f0d98b68
3 changed files with 55 additions and 60 deletions

View File

@@ -1,18 +1,32 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow.contrib.layers as layers
from baselines.a2c import utils from baselines.a2c import utils
from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
from baselines.common.mpi_running_mean_std import RunningMeanStd from baselines.common.mpi_running_mean_std import RunningMeanStd
import tensorflow.contrib.layers as layers
mapping = {} mapping = {}
def register(name):
def register(name, is_rnn=False):
def _thunk(func): def _thunk(func):
if is_rnn:
func = RNN(func)
mapping[name] = func mapping[name] = func
return func return func
return _thunk return _thunk
class RNN(object):
    """Callable wrapper that tags a network (or network builder) as recurrent.

    Calls are forwarded unchanged to the wrapped callable. The wrapper exists
    so that recurrent networks can be recognised with ``isinstance(obj, RNN)``
    instead of by inspecting argument names.
    """

    def __init__(self, func):
        """Wrap ``func``.

        Args:
            func: the callable to forward to. Its ``__name__``, ``__doc__``
                and related metadata are copied onto this wrapper so that a
                decorated builder keeps its identity and docstring.
        """
        # Function-scope import keeps this block self-contained without
        # touching the module's import section.
        import functools

        # Copies __module__/__name__/__qualname__/__doc__ and merges
        # func.__dict__ into ours; attributes missing on ``func`` are skipped.
        functools.update_wrapper(self, func)
        # Assigned AFTER update_wrapper: if ``func`` is itself an RNN, the
        # __dict__ merge above would otherwise leave the inner ``_func`` here.
        self._func = func

    def __call__(self, *args, **kwargs):
        """Forward the call to the wrapped callable and return its result."""
        return self._func(*args, **kwargs)

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self._func)
def nature_cnn(unscaled_images, **conv_kwargs): def nature_cnn(unscaled_images, **conv_kwargs):
""" """
CNN from Nature paper. CNN from Nature paper.
@@ -46,6 +60,7 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
function that builds fully connected network with a given input tensor / placeholder function that builds fully connected network with a given input tensor / placeholder
""" """
def network_fn(X): def network_fn(X):
h = tf.layers.flatten(X) h = tf.layers.flatten(X)
for i in range(num_layers): for i in range(num_layers):
@@ -63,6 +78,7 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
def cnn(**conv_kwargs): def cnn(**conv_kwargs):
def network_fn(X): def network_fn(X):
return nature_cnn(X, **conv_kwargs) return nature_cnn(X, **conv_kwargs)
return network_fn return network_fn
@@ -77,10 +93,11 @@ def cnn_small(**conv_kwargs):
h = conv_to_fc(h) h = conv_to_fc(h)
h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2))) h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2)))
return h return h
return network_fn return network_fn
@register("lstm") @register("lstm", is_rnn=True)
def lstm(nlstm=128, layer_norm=False): def lstm(nlstm=128, layer_norm=False):
""" """
Builds LSTM (Long-Short Term Memory) network to be used in a policy. Builds LSTM (Long-Short Term Memory) network to be used in a policy.
@@ -135,7 +152,7 @@ def lstm(nlstm=128, layer_norm=False):
return network_fn return network_fn
@register("cnn_lstm") @register("cnn_lstm", is_rnn=True)
def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs): def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
def network_fn(X, nenv=1): def network_fn(X, nenv=1):
nbatch = X.shape[0] nbatch = X.shape[0]
@@ -162,7 +179,7 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
return network_fn return network_fn
@register("cnn_lnlstm") @register("cnn_lnlstm", is_rnn=True)
def cnn_lnlstm(nlstm=128, **conv_kwargs): def cnn_lnlstm(nlstm=128, **conv_kwargs):
return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs) return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
@@ -195,8 +212,10 @@ def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
**conv_kwargs) **conv_kwargs)
return out return out
return network_fn return network_fn
def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]): def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
rms = RunningMeanStd(shape=x.shape[1:]) rms = RunningMeanStd(shape=x.shape[1:])
norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range)) norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))

View File

@@ -2,10 +2,10 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from baselines.a2c.utils import ortho_init, fc from baselines.a2c.utils import ortho_init, fc
from baselines.common.models import register from baselines.common.models import register, nature_cnn, RNN
@register("ppo_lstm") @register("ppo_lstm", is_rnn=True)
def ppo_lstm(nlstm=128, layer_norm=False): def ppo_lstm(nlstm=128, layer_norm=False):
def network_fn(input, mask): def network_fn(input, mask):
memory_size = nlstm * 2 memory_size = nlstm * 2
@@ -27,13 +27,13 @@ def ppo_lstm(nlstm=128, layer_norm=False):
h, next_state = lstm(input, mask, state, scope='lstm', nh=nlstm) h, next_state = lstm(input, mask, state, scope='lstm', nh=nlstm)
return h, next_state return h, next_state
return state, _network_fn return state, RNN(_network_fn)
return network_fn return RNN(network_fn)
@register("ppo_cnn_lstm") @register("ppo_cnn_lstm", is_rnn=True)
def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs): def ppo_cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
def network_fn(input, mask): def network_fn(input, mask):
memory_size = nlstm * 2 memory_size = nlstm * 2
nbatch = input.shape[0] nbatch = input.shape[0]
@@ -48,27 +48,7 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
mask = tf.to_float(mask) mask = tf.to_float(mask)
initializer = ortho_init(np.sqrt(2)) initializer = ortho_init(np.sqrt(2))
h = tf.contrib.layers.conv2d(input, h = nature_cnn(input, **conv_kwargs)
num_outputs=32,
kernel_size=8,
stride=4,
padding=pad,
weights_initializer=initializer,
**conv_kwargs)
h = tf.contrib.layers.conv2d(h,
num_outputs=64,
kernel_size=4,
stride=2,
padding=pad,
weights_initializer=initializer,
**conv_kwargs)
h = tf.contrib.layers.conv2d(h,
num_outputs=64,
kernel_size=3,
stride=1,
padding=pad,
weights_initializer=initializer,
**conv_kwargs)
h = tf.layers.flatten(h) h = tf.layers.flatten(h)
h = tf.layers.dense(h, units=512, activation=tf.nn.relu, kernel_initializer=initializer) h = tf.layers.dense(h, units=512, activation=tf.nn.relu, kernel_initializer=initializer)
@@ -78,17 +58,17 @@ def ppo_cnn_lstm(nlstm=128, layer_norm=False, pad='VALID', **conv_kwargs):
h, next_state = lstm(h, mask, state, scope='lstm', nh=nlstm) h, next_state = lstm(h, mask, state, scope='lstm', nh=nlstm)
return h, next_state return h, next_state
return state, _network_fn return state, RNN(_network_fn)
return network_fn return RNN(network_fn)
@register("ppo_cnn_lnlstm") @register("ppo_cnn_lnlstm", is_rnn=True)
def ppo_cnn_lnlstm(nlstm=128, **conv_kwargs): def ppo_cnn_lnlstm(nlstm=128, **conv_kwargs):
return ppo_cnn_lstm(nlstm, layer_norm=True, **conv_kwargs) return ppo_cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
@register("ppo_gru") @register("ppo_gru", is_rnn=True)
def ppo_gru(nlstm=128): def ppo_gru(nlstm=128):
def network_fn(input, mask): def network_fn(input, mask):
memory_size = nlstm memory_size = nlstm
@@ -107,12 +87,12 @@ def ppo_gru(nlstm=128):
h, next_state = gru(input, mask, state, nh=nlstm) h, next_state = gru(input, mask, state, nh=nlstm)
return h, next_state return h, next_state
return state, _network_fn return state, RNN(_network_fn)
return network_fn return RNN(network_fn)
@register("ppo_lstm_mlp") @register("ppo_lstm_mlp", is_rnn=True)
def ppo_lstm_mlp(nlstm=128, layer_norm=False): def ppo_lstm_mlp(nlstm=128, layer_norm=False):
def network_fn(input, mask): def network_fn(input, mask):
memory_size = nlstm * 2 memory_size = nlstm * 2
@@ -138,12 +118,12 @@ def ppo_lstm_mlp(nlstm=128, layer_norm=False):
h = activation(h) h = activation(h)
return h, next_state return h, next_state
return state, _network_fn return state, RNN(_network_fn)
return network_fn return RNN(network_fn)
@register("ppo_gru_mlp") @register("ppo_gru_mlp", is_rnn=True)
def ppo_gru_mlp(nlstm=128): def ppo_gru_mlp(nlstm=128):
def network_fn(input, mask): def network_fn(input, mask):
memory_size = nlstm memory_size = nlstm
@@ -170,9 +150,9 @@ def ppo_gru_mlp(nlstm=128):
return h, next_state return h, next_state
return state, _network_fn return state, RNN(_network_fn)
return network_fn return RNN(network_fn)
def lstm(x, m, s, scope, nh, init_scale=1.0): def lstm(x, m, s, scope, nh, init_scale=1.0):

View File

@@ -1,12 +1,12 @@
import inspect
import gym import gym
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from baselines.a2c.utils import fc from baselines.a2c.utils import fc
from baselines.common import tf_util from baselines.common import tf_util
from baselines.common.distributions import make_pdtype from baselines.common.distributions import make_pdtype
from baselines.common.input import observation_placeholder, encode_observation from baselines.common.input import observation_placeholder, encode_observation
from baselines.common.models import RNN
from baselines.common.models import get_network_builder from baselines.common.models import get_network_builder
from baselines.common.tf_util import adjust_shape from baselines.common.tf_util import adjust_shape
@@ -125,7 +125,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
encoded_x = encode_observation(ob_space, X) encoded_x = encode_observation(ob_space, X)
with tf.variable_scope('load_rnn_memory'): with tf.variable_scope('load_rnn_memory'):
if is_rnn_network(policy_network): if isinstance(policy_network, RNN):
policy_state, policy_network_ = policy_network(encoded_x, dones) policy_state, policy_network_ = policy_network(encoded_x, dones)
else: else:
policy_network_ = policy_network policy_network_ = policy_network
@@ -139,7 +139,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
assert callable(value_network) assert callable(value_network)
value_network_ = value_network value_network_ = value_network
if is_rnn_network(value_network_): if isinstance(value_network_, RNN):
value_state, value_network_ = value_network_(encoded_x, dones) value_state, value_network_ = value_network_(encoded_x, dones)
if policy_state or value_state: if policy_state or value_state:
@@ -154,7 +154,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
index += size index += size
with tf.variable_scope('policy_latent', reuse=tf.AUTO_REUSE): with tf.variable_scope('policy_latent', reuse=tf.AUTO_REUSE):
if is_rnn_network(policy_network_): if isinstance(policy_network_, RNN):
policy_latent, next_policy_state = \ policy_latent, next_policy_state = \
policy_network_(encoded_x, dones, state_map[policy_state]) policy_network_(encoded_x, dones, state_map[policy_state])
next_states_list.append(next_policy_state) next_states_list.append(next_policy_state)
@@ -164,7 +164,7 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
with tf.variable_scope('value_latent', reuse=tf.AUTO_REUSE): with tf.variable_scope('value_latent', reuse=tf.AUTO_REUSE):
if value_network_ == 'shared': if value_network_ == 'shared':
value_latent = policy_latent value_latent = policy_latent
elif is_rnn_network(value_network_): elif isinstance(value_network_, RNN):
value_latent, next_value_state = \ value_latent, next_value_state = \
value_network_(encoded_x, dones, state_map[value_state]) value_network_(encoded_x, dones, state_map[value_state])
next_states_list.append(next_value_state) next_states_list.append(next_value_state)
@@ -201,7 +201,3 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
return policy return policy
return policy_fn return policy_fn
def is_rnn_network(network):
    """Return True when ``network`` declares a positional parameter named ``mask``.

    Args:
        network: the object to examine; normally a network function.

    Returns:
        bool: True if ``'mask'`` is among the positional argument names
        reported by ``inspect.getfullargspec``; False otherwise, including
        when ``network`` cannot be introspected at all.
    """
    try:
        spec = inspect.getfullargspec(network)
    except TypeError:
        # getfullargspec raises TypeError for objects it cannot introspect,
        # e.g. a non-callable such as the 'shared' string sentinel. Such an
        # object is not a recurrent network, so answer False, not crash.
        return False
    return 'mask' in spec.args