Compare commits

..

2 Commits

Author SHA1 Message Date
Peter Zhokhov
b650cd862e lstm network builders using tf lstm 2018-08-10 14:24:55 -07:00
Peter Zhokhov
217b111c88 merged refactor 2018-08-10 14:14:46 -07:00
6 changed files with 56 additions and 10 deletions

View File

@@ -1 +1 @@
ppo2

View File

@@ -139,4 +139,3 @@ To cite this repository in publications:
journal = {GitHub repository},
howpublished = {\url{https://github.com/openai/baselines}},
}

View File

@@ -156,7 +156,7 @@ class FrameStack(gym.Wrapper):
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)
def reset(self):
ob = self.env.reset()
@@ -176,7 +176,6 @@ class FrameStack(gym.Wrapper):
class ScaledFloatFrame(gym.ObservationWrapper):
def __init__(self, env):
gym.ObservationWrapper.__init__(self, env)
self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)
def observation(self, observation):
# careful! This undoes the memory optimization, use

View File

@@ -92,6 +92,48 @@ def lstm(nlstm=128, layer_norm=False):
return network_fn
def tflstm_static(nlstm=128, layer_norm=False):
def network_fn(X, nenv=1):
nbatch = X.shape[0]
nsteps = nbatch // nenv
h = tf.layers.flatten(X)
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(nlstm, state_is_tuple=False, forget_bias=0.0)
S = tf.placeholder(tf.float32, rnn_cell.zero_state(nenv, dtype=tf.float32).shape) #states
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
xs = batch_to_seq(h, nenv, nsteps)
h5, snew = tf.nn.static_rnn(rnn_cell, xs, initial_state=S)
h = seq_to_batch(h5)
initial_state = np.zeros(S.shape.as_list(), dtype=float)
return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
return network_fn
def tflstm(nlstm=128):
def network_fn(X, nenv=1):
nbatch = X.shape[0]
nsteps = nbatch // nenv
h = tf.layers.flatten(X)
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(nlstm, state_is_tuple=False, forget_bias=0.0)
S = tf.placeholder(tf.float32, rnn_cell.zero_state(nenv, dtype=tf.float32).shape) #states
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
initial_state = np.zeros(S.shape)
h = tf.reshape(h, (-1, nsteps, h.shape[-1]))
h, snew = tf.nn.dynamic_rnn(rnn_cell, h, initial_state=S)
h = tf.reshape(h, (-1, h.shape[-1]))
return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
return network_fn
def cnn_lstm(nlstm=128, layer_norm=False, **conv_kwargs):
def network_fn(X, nenv=1):
@@ -138,7 +180,7 @@ def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
'''
def network_fn(X):
out = tf.cast(X, tf.float32) / 255.
out = X
with tf.variable_scope("convnet"):
for num_outputs, kernel_size, stride in convs:
out = layers.convolution2d(out,
@@ -169,6 +211,10 @@ def get_network_builder(name):
return mlp
elif name == 'lstm':
return lstm
elif name == 'tflstm_static':
return tflstm_static
elif name == 'tflstm':
return tflstm
elif name == 'cnn_lstm':
return cnn_lstm
elif name == 'cnn_lnlstm':

View File

@@ -6,7 +6,8 @@ from baselines.run import get_learn_function
common_kwargs = dict(
seed=0,
total_timesteps=50000,
total_timesteps=20000,
nlstm=64
)
learn_kwargs = {
@@ -19,7 +20,7 @@ learn_kwargs = {
alg_list = learn_kwargs.keys()
rnn_list = ['lstm']
rnn_list = ['lstm', 'tflstm', 'tflstm_static']
@pytest.mark.slow
@pytest.mark.parametrize("alg", alg_list)
@@ -41,11 +42,11 @@ def test_fixed_sequence(alg, rnn):
**kwargs
)
simple_test(env_fn, learn, 0.7)
simple_test(env_fn, learn, 0.3)
if __name__ == '__main__':
test_fixed_sequence('ppo2', 'lstm')
test_fixed_sequence('ppo2', 'tflstm')

View File

@@ -2,6 +2,7 @@ import tensorflow as tf
import numpy as np
from gym.spaces import np_random
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.bench.monitor import Monitor
N_TRIALS = 10000
N_EPISODES = 100
@@ -10,7 +11,7 @@ def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
np.random.seed(0)
np_random.seed(0)
env = DummyVecEnv([env_fn])
env = DummyVecEnv([lambda: Monitor(env_fn(), None, allow_early_resets=True)])
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():