Compare commits
1 Commit
peterz_tfl
...
peterz_tfl
Author | SHA1 | Date | |
---|---|---|---|
| | dbcc4e0252 | | |
@@ -6,7 +6,8 @@ from baselines.run import get_learn_function
|
||||
|
||||
common_kwargs = dict(
|
||||
seed=0,
|
||||
total_timesteps=50000,
|
||||
total_timesteps=20000,
|
||||
nlstm=64
|
||||
)
|
||||
|
||||
learn_kwargs = {
|
||||
@@ -19,7 +20,7 @@ learn_kwargs = {
|
||||
|
||||
|
||||
alg_list = learn_kwargs.keys()
|
||||
rnn_list = ['lstm']
|
||||
rnn_list = ['lstm', 'tflstm', 'tflstm_static']
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alg", alg_list)
|
||||
@@ -41,11 +42,11 @@ def test_fixed_sequence(alg, rnn):
|
||||
**kwargs
|
||||
)
|
||||
|
||||
simple_test(env_fn, learn, 0.7)
|
||||
simple_test(env_fn, learn, 0.3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_fixed_sequence('ppo2', 'lstm')
|
||||
test_fixed_sequence('ppo2', 'tflstm')
|
||||
|
||||
|
||||
|
||||
|
@@ -2,6 +2,7 @@ import tensorflow as tf
|
||||
import numpy as np
|
||||
from gym.spaces import np_random
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from baselines.bench.monitor import Monitor
|
||||
|
||||
N_TRIALS = 10000
|
||||
N_EPISODES = 100
|
||||
@@ -10,7 +11,7 @@ def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
||||
np.random.seed(0)
|
||||
np_random.seed(0)
|
||||
|
||||
env = DummyVecEnv([env_fn])
|
||||
env = DummyVecEnv([lambda: Monitor(env_fn(), None, allow_early_resets=True)])
|
||||
|
||||
|
||||
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
||||
|
Reference in New Issue
Block a user