move RNN class to `baselines/ppo2/layers.py` and revert `baselines/common/models.py` to 858afa8.

gyunt
2019-04-09 01:53:10 +09:00
parent b6e6c5201a
commit bb2523f54d
3 changed files with 46 additions and 66 deletions
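
The class being moved is small; for orientation, here it is as it lands in baselines/ppo2/layers.py (body verbatim from the diff below, comments added):

    class RNN(object):
        # Wraps a network-builder function and records how much recurrent
        # state the caller must allocate (e.g. num_units * 2 for an LSTM,
        # which carries both cell and hidden state).
        def __init__(self, func, memory_size=None):
            self._func = func
            self.memory_size = memory_size

        def __call__(self, *args, **kwargs):
            return self._func(*args, **kwargs)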

baselines/common/models.py

@@ -1,33 +1,18 @@
 import numpy as np
 import tensorflow as tf
-import tensorflow.contrib.layers as layers
 from baselines.a2c import utils
 from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
 from baselines.common.mpi_running_mean_std import RunningMeanStd
+import tensorflow.contrib.layers as layers
 
 mapping = {}
 
-def register(name, is_rnn=False):
+def register(name):
     def _thunk(func):
-        if is_rnn:
-            func = RNN(func)
         mapping[name] = func
         return func
     return _thunk
 
-class RNN(object):
-    def __init__(self, func, memory_size=None):
-        self._func = func
-        self.memory_size = memory_size
-
-    def __call__(self, *args, **kwargs):
-        return self._func(*args, **kwargs)
-
 def nature_cnn(unscaled_images, **conv_kwargs):
     """
     CNN from Nature paper.
@@ -61,7 +46,6 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
     function that builds fully connected network with a given input tensor / placeholder
     """
     def network_fn(X):
-
         h = tf.layers.flatten(X)
         for i in range(num_layers):
@@ -79,7 +63,6 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
 def cnn(**conv_kwargs):
     def network_fn(X):
-
         return nature_cnn(X, **conv_kwargs)
     return network_fn
@@ -94,11 +77,10 @@ def cnn_small(**conv_kwargs):
         h = conv_to_fc(h)
         h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2)))
         return h
     return network_fn
 
-
-@register("lstm", is_rnn=True)
+@register("lstm")
 def lstm(nlstm=128, layer_norm=False):
     """
     Builds LSTM (Long-Short Term Memory) network to be used in a policy.
@@ -134,8 +116,8 @@ def lstm(nlstm=128, layer_norm=False):
         h = tf.layers.flatten(X)
 
-        M = tf.placeholder(tf.float32, [nbatch]) # mask (done t-1)
-        S = tf.placeholder(tf.float32, [nenv, 2 * nlstm]) # states
+        M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
+        S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
 
         xs = batch_to_seq(h, nenv, nsteps)
         ms = batch_to_seq(M, nenv, nsteps)
@@ -148,12 +130,12 @@ def lstm(nlstm=128, layer_norm=False):
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
 
-        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
+        return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
 
     return network_fn
 
-@register("cnn_lstm", is_rnn=True)
+@register("cnn_lstm")
 def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     def network_fn(X, nenv=1):
         nbatch = X.shape[0]
@@ -161,8 +143,8 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
         h = nature_cnn(X, **conv_kwargs)
 
-        M = tf.placeholder(tf.float32, [nbatch]) # mask (done t-1)
-        S = tf.placeholder(tf.float32, [nenv, 2 * nlstm]) # states
+        M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
+        S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
 
         xs = batch_to_seq(h, nenv, nsteps)
         ms = batch_to_seq(M, nenv, nsteps)
@@ -175,12 +157,12 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
 
-        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
+        return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
 
     return network_fn
 
-@register("cnn_lnlstm", is_rnn=True)
+@register("cnn_lnlstm")
 def cnn_lnlstm(nlstm=128, **conv_kwargs):
     return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
@@ -213,10 +195,8 @@ def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
                                **conv_kwargs)
         return out
 
     return network_fn
 
-
-
 def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
     rms = RunningMeanStd(shape=x.shape[1:])
     norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
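
With the revert, `register` in models.py is again a plain name-to-function registry with no RNN special-casing. A minimal sketch of how the registry is consumed, using the `get_network_builder` accessor that policies.py imports below ("my_mlp" is a made-up example name):

    from baselines.common.models import register, get_network_builder

    @register("my_mlp")
    def my_mlp(num_layers=2, num_hidden=64):
        def network_fn(X):
            ...  # build layers on the input tensor X and return the output
        return network_fn

    builder = get_network_builder("my_mlp")  # looks the function up in mapping
    network_fn = builder(num_layers=3)       # calling it yields the actual network_fn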

baselines/ppo2/layers.py

@@ -2,12 +2,21 @@ import numpy as np
 import tensorflow as tf
 
 from baselines.a2c.utils import ortho_init, fc, lstm, lnlstm
-from baselines.common.models import register, nature_cnn, RNN
+from baselines.common.models import register, nature_cnn
 
-@register("ppo_lstm", is_rnn=True)
+
+class RNN(object):
+    def __init__(self, func, memory_size=None):
+        self._func = func
+        self.memory_size = memory_size
+
+    def __call__(self, *args, **kwargs):
+        return self._func(*args, **kwargs)
+
+
+@register("ppo_lstm")
 def ppo_lstm(num_units=128, layer_norm=False):
-    def _network_fn(input, mask, state):
+    def network_fn(input, mask, state):
         input = tf.layers.flatten(input)
         mask = tf.to_float(mask)
@@ -18,12 +27,12 @@ def ppo_lstm(num_units=128, layer_norm=False):
         h = h[0]
         return h, next_state
 
-    return RNN(_network_fn, memory_size=num_units * 2)
+    return RNN(network_fn, memory_size=num_units * 2)
 
-@register("ppo_cnn_lstm", is_rnn=True)
+@register("ppo_cnn_lstm")
 def ppo_cnn_lstm(num_units=128, layer_norm=False, **conv_kwargs):
-    def _network_fn(input, mask, state):
+    def network_fn(input, mask, state):
         mask = tf.to_float(mask)
         initializer = ortho_init(np.sqrt(2))
@@ -38,41 +47,32 @@ def ppo_cnn_lstm(num_units=128, layer_norm=False, **conv_kwargs):
         h = h[0]
         return h, next_state
 
-    return RNN(_network_fn, memory_size=num_units * 2)
+    return RNN(network_fn, memory_size=num_units * 2)
 
-@register("ppo_cnn_lnlstm", is_rnn=True)
+@register("ppo_cnn_lnlstm")
 def ppo_cnn_lnlstm(num_units=128, **conv_kwargs):
     return ppo_cnn_lstm(num_units, layer_norm=True, **conv_kwargs)
 
-@register("ppo_lstm_mlp", is_rnn=True)
+@register("ppo_lstm_mlp")
 def ppo_lstm_mlp(num_units=128, layer_norm=False):
-    def network_fn(input, mask):
-        memory_size = num_units * 2
-        nbatch = input.shape[0]
-        mask.get_shape().assert_is_compatible_with([nbatch])
-        state = tf.Variable(np.zeros([nbatch, memory_size]),
-                            name='lstm_state',
-                            trainable=False,
-                            dtype=tf.float32,
-                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
-
-        def _network_fn(input, mask, state):
-            input = tf.layers.flatten(input)
-            mask = tf.to_float(mask)
-
-            h, next_state = lstm([input], [mask[:, None]], state, scope='lstm', nh=num_units)
-            h = h[0]
-
-            num_layers = 2
-            num_hidden = 64
-            activation = tf.nn.relu
-            for i in range(num_layers):
-                h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
-                h = activation(h)
-            return h, next_state
-
-        return state, RNN(_network_fn)
-    return RNN(network_fn)
+    def _network_fn(input, mask, state):
+        h = tf.layers.flatten(input)
+        mask = tf.to_float(mask)
+
+        if layer_norm:
+            h, next_state = lnlstm([h], [mask[:, None]], state, scope='lnlstm', nh=num_units)
+        else:
+            h, next_state = lstm([h], [mask[:, None]], state, scope='lstm', nh=num_units)
+        h = h[0]
+
+        num_layers = 2
+        num_hidden = 64
+        activation = tf.nn.relu
+        for i in range(num_layers):
+            h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
+            h = activation(h)
+        return h, next_state
+
+    return RNN(_network_fn, num_units * 2)
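
The builders in layers.py now hand back an `RNN` instance themselves rather than relying on `register(..., is_rnn=True)` to wrap them. A hedged usage sketch (the printed value follows from `memory_size=num_units * 2` in the diff above):

    from baselines.ppo2.layers import ppo_lstm

    rnn = ppo_lstm(num_units=128)  # an RNN wrapping the inner network_fn
    print(rnn.memory_size)         # 256: cell state plus hidden state
    # Inside a policy builder one would then call, roughly:
    #   h, next_state = rnn(obs_tensor, mask_tensor, state_tensor)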

baselines/common/policies.py

@@ -6,9 +6,9 @@ from baselines.a2c.utils import fc
 from baselines.common import tf_util
 from baselines.common.distributions import make_pdtype
 from baselines.common.input import observation_placeholder, encode_observation
-from baselines.common.models import RNN
 from baselines.common.models import get_network_builder
 from baselines.common.tf_util import adjust_shape
+from baselines.ppo2.layers import RNN
 
 class PolicyWithValue(object):
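
The import swap is the whole change for this file: policies.py can still type-check networks against `RNN`, but the class now lives with the PPO2 code. Presumably (the check itself is outside this hunk) the policy builder branches on the wrapper roughly like this:

    from baselines.ppo2.layers import RNN

    if isinstance(network, RNN):           # recurrent: allocate network.memory_size
        memory_size = network.memory_size  # floats of state per environment and
        ...                                # thread mask and state through each step
    else:
        ...                                # feed-forward: plain call on observations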