* joshim5 changes (width and height to WarpFrame wrapper)
* match network output with action distribution via a linear layer only if necessary (#167)
* support color vs. grayscale option in WarpFrame wrapper (#166)
* support color vs. grayscale option in WarpFrame wrapper
* Support color in other wrappers
* Updated per Peters suggestions
* fixing test failures
* ppo2 with microbatches (#168)
* pass microbatch_size to the model during construction
* microbatch fixes and test (#169)
* microbatch fixes and test
* tiny cleanup
* added assertions to the test
* vpg-related fix
* Peterz joshim5 subclass ppo2 model (#170)
* microbatch fixes and test
* tiny cleanup
* added assertions to the test
* vpg-related fix
* subclassing the model to make microbatched version of model WIP
* made microbatched model a subclass of ppo2 Model
* flake8 complaint
* mpi-less ppo2 (resolving merge conflict)
* flake8 and mpi4py imports in ppo2/model.py
* more un-mpying
* merge master
* updates to the benchmark viewer code + autopep8 (#184)
* viz docs and syntactic sugar wip
* update viewer yaml to use persistent volume claims
* move plot_util to baselines.common, update links
* use 1Tb hard drive for results viewer
* small updates to benchmark vizualizer code
* autopep8
* autopep8
* any folder can be a benchmark
* massage games image a little bit
* fixed --preload option in app.py
* remove preload from run_viewer.sh
* remove pdb breakpoints
* update bench-viewer.yaml
* fixed bug (#185)
* fixed bug
  it's wrong to do the else statement, because no other nodes would start.
* changed the fix slightly
* Refactor her phase 1 (#194)
* add monitor to the rollout envs in her RUN BENCHMARKS her
* Slice -> Slide in her benchmarks RUN BENCHMARKS her
* run her benchmark for 200 epochs
* dummy commit to RUN BENCHMARKS her
* her benchmark for 500 epochs RUN BENCHMARKS her
* add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her
* add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her
* add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her
* disable saving of policies in her benchmark RUN BENCHMARKS her
* run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch
* run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch
* launcher refactor wip
* wip
* her works on FetchReach
* her runner refactor RUN BENCHMARKS Fetch1M
* unit test for her
* fixing warnings in mpi_average in her, skip test_fetchreach if mujoco is not present
* pickle-based serialization in her
* remove extra import from subproc_vec_env.py
* investigating differences in rollout.py
* try with old rollout code RUN BENCHMARKS her
* temporarily use DummyVecEnv in cmd_util.py RUN BENCHMARKS her
* dummy commit to RUN BENCHMARKS her
* set info_values in rollout worker in her RUN BENCHMARKS her
* bug in rollout_new.py RUN BENCHMARKS her
* fixed bug in rollout_new.py RUN BENCHMARKS her
* do not use last step because vecenv calls reset and returns obs after reset RUN BENCHMARKS her
* updated buffer sizes RUN BENCHMARKS her
* fixed loading/saving via joblib
* dust off learning from demonstrations in HER, docs, refactor
* add deprecation notice on her play and plot files
* address comments by Matthias
40 lines
822 B
Python
import pytest
import gym

from baselines.run import get_learn_function
from baselines.common.tests.util import reward_per_episode_test

# Skip every test in this module when mujoco_py is not installed.
pytest.importorskip('mujoco_py')

# Keyword arguments shared by all algorithms under test.
common_kwargs = dict(
    network='mlp',
    seed=0,
)

# Per-algorithm keyword arguments, merged over common_kwargs in the test.
learn_kwargs = {
    'her': dict(total_timesteps=2000)
}


@pytest.mark.slow
@pytest.mark.parametrize("alg", learn_kwargs.keys())
def test_fetchreach(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn the FetchReach task
    '''

    kwargs = common_kwargs.copy()
    kwargs.update(learn_kwargs[alg])

    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)

    def env_fn():
        env = gym.make('FetchReach-v1')
        env.seed(0)
        return env

    reward_per_episode_test(env_fn, learn_fn, -15)


if __name__ == '__main__':
    test_fetchreach('her')
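The test above hands a reward threshold (`-15`) to `reward_per_episode_test`, whose implementation is not shown in this file. As a rough illustration of what a threshold check of this kind could look like, here is a minimal sketch. It is not the baselines helper: `ToyReachEnv`, `average_episode_reward`, `check_min_average_reward`, and `always_right` are hypothetical names invented for this example, and the toy environment stands in for FetchReach so the snippet runs without gym or MuJoCo.

```python
class ToyReachEnv:
    """Hypothetical stand-in for a goal-reaching task (not FetchReach).

    The agent starts at position 0 and receives -1 per step until it
    lands on `goal`; episodes are capped at `max_steps`.
    """

    def __init__(self, goal=3, max_steps=10):
        self.goal = goal
        self.max_steps = max_steps
        self.pos = 0
        self.steps = 0

    def reset(self):
        self.pos = 0
        self.steps = 0
        return self.pos

    def step(self, action):
        # `action` is +1 (move right) or -1 (move left).
        self.pos += action
        self.steps += 1
        reached = self.pos == self.goal
        done = reached or self.steps >= self.max_steps
        reward = 0.0 if reached else -1.0
        return self.pos, reward, done, {}


def average_episode_reward(env_fn, policy, n_episodes=20):
    """Run `policy` for `n_episodes` and return the mean total reward."""
    env = env_fn()
    totals = []
    for _ in range(n_episodes):
        obs, done, total = env.reset(), False, 0.0
        while not done:
            obs, reward, done, _ = env.step(policy(obs))
            total += reward
        totals.append(total)
    return sum(totals) / len(totals)


def check_min_average_reward(env_fn, policy, min_avg_reward, n_episodes=20):
    """Assert that the mean episode reward is at least `min_avg_reward`."""
    avg = average_episode_reward(env_fn, policy, n_episodes)
    assert avg >= min_avg_reward, f"average reward {avg} < {min_avg_reward}"
    return avg


def always_right(obs):
    # Trivial hand-written policy: always step toward the goal.
    return 1


avg = check_min_average_reward(ToyReachEnv, always_right, min_avg_reward=-5)
```

In this sketch a policy that always moves right reaches the goal in three steps, so every episode totals -2 and the check passes against a threshold of -5. A policy that never reaches the goal would accumulate -10 per episode and trip the assertion.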