move RNN class to `baselines/ppo2/layers.py` and revert `baselines/common/models.py` to 858afa8
diff --git a/baselines/common/models.py b/baselines/common/models.py
--- a/baselines/common/models.py
+++ b/baselines/common/models.py
@@ -1,33 +1,18 @@
 import numpy as np
 import tensorflow as tf
-import tensorflow.contrib.layers as layers
-
 from baselines.a2c import utils
 from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
 from baselines.common.mpi_running_mean_std import RunningMeanStd
+import tensorflow.contrib.layers as layers
 
 mapping = {}
 
-
-def register(name, is_rnn=False):
+def register(name):
     def _thunk(func):
-        if is_rnn:
-            func = RNN(func)
         mapping[name] = func
         return func
-
     return _thunk
 
-
-class RNN(object):
-    def __init__(self, func, memory_size=None):
-        self._func = func
-        self.memory_size = memory_size
-
-    def __call__(self, *args, **kwargs):
-        return self._func(*args, **kwargs)
-
-
 def nature_cnn(unscaled_images, **conv_kwargs):
     """
     CNN from Nature paper.
@@ -61,7 +46,6 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
 
     function that builds fully connected network with a given input tensor / placeholder
     """
-
     def network_fn(X):
         h = tf.layers.flatten(X)
         for i in range(num_layers):
@@ -79,7 +63,6 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
 def cnn(**conv_kwargs):
     def network_fn(X):
         return nature_cnn(X, **conv_kwargs)
 
     return network_fn
 
-
@@ -94,11 +77,10 @@ def cnn_small(**conv_kwargs):
         h = conv_to_fc(h)
         h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2)))
         return h
 
     return network_fn
 
-
-@register("lstm", is_rnn=True)
+@register("lstm")
 def lstm(nlstm=128, layer_norm=False):
     """
     Builds LSTM (Long-Short Term Memory) network to be used in a policy.
@@ -134,8 +116,8 @@ def lstm(nlstm=128, layer_norm=False):
 
         h = tf.layers.flatten(X)
 
-        M = tf.placeholder(tf.float32, [nbatch]) # mask (done t-1)
-        S = tf.placeholder(tf.float32, [nenv, 2 * nlstm]) # states
+        M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
+        S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
 
         xs = batch_to_seq(h, nenv, nsteps)
         ms = batch_to_seq(M, nenv, nsteps)
@@ -148,12 +130,12 @@ def lstm(nlstm=128, layer_norm=False):
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
 
-        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
+        return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
 
     return network_fn
 
 
-@register("cnn_lstm", is_rnn=True)
+@register("cnn_lstm")
 def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
     def network_fn(X, nenv=1):
         nbatch = X.shape[0]
@@ -161,8 +143,8 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
 
         h = nature_cnn(X, **conv_kwargs)
 
-        M = tf.placeholder(tf.float32, [nbatch]) # mask (done t-1)
-        S = tf.placeholder(tf.float32, [nenv, 2 * nlstm]) # states
+        M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
+        S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
 
         xs = batch_to_seq(h, nenv, nsteps)
         ms = batch_to_seq(M, nenv, nsteps)
@@ -175,12 +157,12 @@ def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
         h = seq_to_batch(h5)
         initial_state = np.zeros(S.shape.as_list(), dtype=float)
 
-        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
+        return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
 
     return network_fn
 
 
-@register("cnn_lnlstm", is_rnn=True)
+@register("cnn_lnlstm")
 def cnn_lnlstm(nlstm=128, **conv_kwargs):
     return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
 
@@ -213,10 +195,8 @@ def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
                        **conv_kwargs)
 
         return out
-
     return network_fn
 
-
 def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
     rms = RunningMeanStd(shape=x.shape[1:])
     norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
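For context, the `register` decorator that models.py reverts to is a plain name-to-builder registry: decorating a function stores it in `mapping` under a string key, and callers later look the builder up by name. A minimal sketch of the pattern, assuming TF 1.x as in the repo (the `my_mlp` name and the dense layer are illustrative, not part of this commit):

import tensorflow as tf

mapping = {}

def register(name):
    # Store the decorated builder under `name` for later string lookup.
    def _thunk(func):
        mapping[name] = func
        return func
    return _thunk

@register("my_mlp")  # illustrative name, not in the diff
def my_mlp(num_hidden=64):
    # Builders return a network_fn that maps an input tensor to features.
    def network_fn(X):
        h = tf.layers.flatten(X)
        return tf.layers.dense(h, num_hidden, activation=tf.tanh)
    return network_fn

network_fn = mapping["my_mlp"](num_hidden=32)  # look up by name, then build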
diff --git a/baselines/ppo2/layers.py b/baselines/ppo2/layers.py
--- a/baselines/ppo2/layers.py
+++ b/baselines/ppo2/layers.py
@@ -2,12 +2,21 @@ import numpy as np
 import tensorflow as tf
 
 from baselines.a2c.utils import ortho_init, fc, lstm, lnlstm
-from baselines.common.models import register, nature_cnn, RNN
+from baselines.common.models import register, nature_cnn
 
 
-@register("ppo_lstm", is_rnn=True)
+class RNN(object):
+    def __init__(self, func, memory_size=None):
+        self._func = func
+        self.memory_size = memory_size
+
+    def __call__(self, *args, **kwargs):
+        return self._func(*args, **kwargs)
+
+
+@register("ppo_lstm")
 def ppo_lstm(num_units=128, layer_norm=False):
-    def _network_fn(input, mask, state):
+    def network_fn(input, mask, state):
         input = tf.layers.flatten(input)
         mask = tf.to_float(mask)
 
@@ -18,12 +27,12 @@ def ppo_lstm(num_units=128, layer_norm=False):
         h = h[0]
         return h, next_state
 
-    return RNN(_network_fn, memory_size=num_units * 2)
+    return RNN(network_fn, memory_size=num_units * 2)
 
 
-@register("ppo_cnn_lstm", is_rnn=True)
+@register("ppo_cnn_lstm")
 def ppo_cnn_lstm(num_units=128, layer_norm=False, **conv_kwargs):
-    def _network_fn(input, mask, state):
+    def network_fn(input, mask, state):
         mask = tf.to_float(mask)
         initializer = ortho_init(np.sqrt(2))
 
@@ -38,41 +47,32 @@ def ppo_cnn_lstm(num_units=128, layer_norm=False, **conv_kwargs):
         h = h[0]
         return h, next_state
 
-    return RNN(_network_fn, memory_size=num_units * 2)
+    return RNN(network_fn, memory_size=num_units * 2)
 
 
-@register("ppo_cnn_lnlstm", is_rnn=True)
+@register("ppo_cnn_lnlstm")
 def ppo_cnn_lnlstm(num_units=128, **conv_kwargs):
     return ppo_cnn_lstm(num_units, layer_norm=True, **conv_kwargs)
 
 
-@register("ppo_lstm_mlp", is_rnn=True)
+@register("ppo_lstm_mlp")
 def ppo_lstm_mlp(num_units=128, layer_norm=False):
-    def network_fn(input, mask):
-        memory_size = num_units * 2
-        nbatch = input.shape[0]
-        mask.get_shape().assert_is_compatible_with([nbatch])
-        state = tf.Variable(np.zeros([nbatch, memory_size]),
-                            name='lstm_state',
-                            trainable=False,
-                            dtype=tf.float32,
-                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
-
-        def _network_fn(input, mask, state):
-            input = tf.layers.flatten(input)
-            mask = tf.to_float(mask)
-
-            h, next_state = lstm([input], [mask[:, None]], state, scope='lstm', nh=num_units)
-            h = h[0]
-
-            num_layers = 2
-            num_hidden = 64
-            activation = tf.nn.relu
-            for i in range(num_layers):
-                h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
-                h = activation(h)
-            return h, next_state
-
-        return state, RNN(_network_fn)
-
-    return RNN(network_fn)
+    def _network_fn(input, mask, state):
+        h = tf.layers.flatten(input)
+        mask = tf.to_float(mask)
+
+        if layer_norm:
+            h, next_state = lnlstm([h], [mask[:, None]], state, scope='lnlstm', nh=num_units)
+        else:
+            h, next_state = lstm([h], [mask[:, None]], state, scope='lstm', nh=num_units)
+        h = h[0]
+
+        num_layers = 2
+        num_hidden = 64
+        activation = tf.nn.relu
+        for i in range(num_layers):
+            h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
+            h = activation(h)
+        return h, next_state
+
+    return RNN(_network_fn, num_units * 2)
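With the class moved here, the marker for recurrence becomes the `RNN` wrapper type itself rather than the registry's removed `is_rnn=True` flag. A rough usage sketch under that reading (variable names are illustrative; only `RNN`, `ppo_lstm`, and `memory_size` come from the diff):

from baselines.ppo2.layers import RNN, ppo_lstm

net = ppo_lstm(num_units=128)  # returns RNN(network_fn, memory_size=256)

# A type check can now replace the old is_rnn registry flag.
if isinstance(net, RNN):
    print("recurrent, memory_size =", net.memory_size)  # 2 * num_units

# RNN.__call__ forwards straight to the wrapped network_fn, so a
# recurrent net is applied as: features, next_state = net(obs, mask, state)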
diff --git a/baselines/common/policies.py b/baselines/common/policies.py
--- a/baselines/common/policies.py
+++ b/baselines/common/policies.py
@@ -6,9 +6,9 @@ from baselines.a2c.utils import fc
 from baselines.common import tf_util
 from baselines.common.distributions import make_pdtype
 from baselines.common.input import observation_placeholder, encode_observation
-from baselines.common.models import RNN
 from baselines.common.models import get_network_builder
 from baselines.common.tf_util import adjust_shape
+from baselines.ppo2.layers import RNN
 
 
 class PolicyWithValue(object):
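The import swap above is the other half of the move: `baselines/common/models.py` no longer defines `RNN`, so the policy module pulls it from `baselines/ppo2/layers`. A hypothetical helper (not code from this commit) showing the kind of branching this import enables when building a policy's feature network:

from baselines.ppo2.layers import RNN

def build_features(network, obs, mask=None, state=None):
    # Hypothetical: recurrent builders take (obs, mask, state) and also
    # return a next_state; plain builders just map obs to features.
    if isinstance(network, RNN):
        return network(obs, mask, state)
    return network(obs), None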