* make acer use VecFrameStack
* acer passes MNIST test with 20k steps
* acer with non-image observations and tests
* flake8
* test acer serialization with non-recurrent policies
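Since the first bullet is about frame stacking, here is a minimal sketch of how a vectorized env is typically wrapped with `VecFrameStack` before being handed to a learn function. It assumes the `DummyVecEnv` and `VecFrameStack` wrappers from `baselines.common.vec_env`; the env id and `nstack` value are illustrative, and the actual wiring inside acer's run path may differ.

```python
# Hedged sketch: frame-stacking a vectorized env, as in the first bullet.
# The env id and nstack value are illustrative, not taken from this change.
import gym

from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_frame_stack import VecFrameStack


def make_stacked_env(env_id='PongNoFrameskip-v4', nstack=4):
    # DummyVecEnv turns a list of env thunks into a single vectorized env;
    # VecFrameStack then concatenates the last `nstack` observations
    # along the last (channel) axis.
    venv = DummyVecEnv([lambda: gym.make(env_id)])
    return VecFrameStack(venv, nstack)
```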
import pytest

# from baselines.acer import acer_simple as acer
from baselines.common.tests.envs.mnist_env import MnistEnv
from baselines.common.tests.util import simple_test
from baselines.run import get_learn_function

# TODO investigate a2c and ppo2 failures - is it due to bad hyperparameters for this problem?
# GitHub issue https://github.com/openai/baselines/issues/189
common_kwargs = {
    'seed': 0,
    'network': 'cnn',
    'gamma': 0.9,
    'pad': 'SAME'
}

learn_args = {
    'a2c': dict(total_timesteps=50000),
    'acer': dict(total_timesteps=20000),
    'deepq': dict(total_timesteps=5000),
    'acktr': dict(total_timesteps=30000),
    'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0),
    'trpo_mpi': dict(total_timesteps=80000, timesteps_per_batch=100, cg_iters=10, lam=1.0, max_kl=0.001)
}


# Tests pass, but are too slow on Travis. The same algorithms are covered
# by other tests with less compute-hungry networks and by the benchmarks.
@pytest.mark.skip
@pytest.mark.slow
@pytest.mark.parametrize("alg", learn_args.keys())
def test_mnist(alg):
    '''
    Test if the algorithm can learn to classify MNIST digits.
    Uses CNN policy.
    '''
    learn_kwargs = learn_args[alg]
    learn_kwargs.update(common_kwargs)

    learn = get_learn_function(alg)
    learn_fn = lambda e: learn(env=e, **learn_kwargs)
    env_fn = lambda: MnistEnv(seed=0, episode_len=100)

    simple_test(env_fn, learn_fn, 0.6)


if __name__ == '__main__':
    test_mnist('acer')
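For context, `simple_test` (from `baselines.common.tests.util`) trains on the env and asserts that the resulting policy clears the given reward threshold. Below is a hedged sketch of that contract, not the real helper, which vectorizes the env and uses the trained model's own API; `check_learned_policy` and the callable `policy` convention are hypothetical names for illustration.

```python
# Hedged sketch (not the actual simple_test implementation): after training,
# roll out the policy and require the average per-step reward to reach
# `min_avg_reward` (0.6 in the test above). `policy` is a hypothetical
# callable mapping an observation to an action.
def check_learned_policy(policy, env, min_avg_reward, n_steps=100):
    obs = env.reset()
    total_reward = 0.0
    for _ in range(n_steps):
        obs, rew, done, _ = env.step(policy(obs))
        total_reward += rew
        if done:
            obs = env.reset()
    assert total_reward / n_steps >= min_avg_reward
```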