Files
baselines/baselines/common/tests/test_fetchreach.py
pzhokhov 6c44fb28fe refactor HER - phase 1 (#767)
* joshim5 changes (width and height to WarpFrame wrapper)

* match network output with action distribution via a linear layer only if necessary (#167)

* support color vs. grayscale option in WarpFrame wrapper (#166)

* support color vs. grayscale option in WarpFrame wrapper

* Support color in other wrappers

* Updated per Peters suggestions

* fixing test failures

* ppo2 with microbatches (#168)

* pass microbatch_size to the model during construction

* microbatch fixes and test (#169)

* microbatch fixes and test

* tiny cleanup

* added assertions to the test

* vpg-related fix

* Peterz joshim5 subclass ppo2 model (#170)

* microbatch fixes and test

* tiny cleanup

* added assertions to the test

* vpg-related fix

* subclassing the model to make microbatched version of model WIP

* made microbatched model a subclass of ppo2 Model

* flake8 complaint

* mpi-less ppo2 (resolving merge conflict)

* flake8 and mpi4py imports in ppo2/model.py

* more un-mpying

* merge master

* updates to the benchmark viewer code + autopep8 (#184)

* viz docs and syntactic sugar wip

* update viewer yaml to use persistent volume claims

* move plot_util to baselines.common, update links

* use 1Tb hard drive for results viewer

* small updates to benchmark vizualizer code

* autopep8

* autopep8

* any folder can be a benchmark

* massage games image a little bit

* fixed --preload option in app.py

* remove preload from run_viewer.sh

* remove pdb breakpoints

* update bench-viewer.yaml

* fixed bug (#185)

* fixed bug 

it's wrong to do the else statement, because no other nodes would start.

* changed the fix slightly

* Refactor her phase 1 (#194)

* add monitor to the rollout envs in her RUN BENCHMARKS her

* Slice -> Slide in her benchmarks RUN BENCHMARKS her

* run her benchmark for 200 epochs

* dummy commit to RUN BENCHMARKS her

* her benchmark for 500 epochs RUN BENCHMARKS her

* add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her

* add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her

* add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her

* disable saving of policies in her benchmark RUN BENCHMARKS her

* run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch

* run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch

* launcher refactor wip

* wip

* her works on FetchReach

* her runner refactor RUN BENCHMARKS Fetch1M

* unit test for her

* fixing warnings in mpi_average in her, skip test_fetchreach if mujoco is not present

* pickle-based serialization in her

* remove extra import from subproc_vec_env.py

* investigating differences in rollout.py

* try with old rollout code RUN BENCHMARKS her

* temporarily use DummyVecEnv in cmd_util.py RUN BENCHMARKS her

* dummy commit to RUN BENCHMARKS her

* set info_values in rollout worker in her RUN BENCHMARKS her

* bug in rollout_new.py RUN BENCHMARKS her

* fixed bug in rollout_new.py RUN BENCHMARKS her

* do not use last step because vecenv calls reset and returns obs after reset RUN BENCHMARKS her

* updated buffer sizes RUN BENCHMARKS her

* fixed loading/saving via joblib

* dust off learning from demonstrations in HER, docs, refactor

* add deprecation notice on her play and plot files

* address comments by Matthias
2018-12-19 14:44:08 -08:00

40 lines
822 B
Python

import pytest
import gym
from baselines.run import get_learn_function
from baselines.common.tests.util import reward_per_episode_test
pytest.importorskip('mujoco_py')
common_kwargs = dict(
network='mlp',
seed=0,
)
learn_kwargs = {
'her': dict(total_timesteps=2000)
}
@pytest.mark.slow
@pytest.mark.parametrize("alg", learn_kwargs.keys())
def test_fetchreach(alg):
'''
Test if the algorithm (with an mlp policy)
can learn the FetchReach task
'''
kwargs = common_kwargs.copy()
kwargs.update(learn_kwargs[alg])
learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
def env_fn():
env = gym.make('FetchReach-v1')
env.seed(0)
return env
reward_per_episode_test(env_fn, learn_fn, -15)
if __name__ == '__main__':
test_fetchreach('her')