From 1d56af90d3b7af545ff1afd0b9cb1fe23e44da89 Mon Sep 17 00:00:00 2001
From: Karl Cobbe
Date: Thu, 24 Jan 2019 10:10:59 -0800
Subject: [PATCH] Vecenv refactor (#223)

* update karl util
* restore pvi flag
* change rcall auto cpu behavior, move gin.configurable, add os.makedirs
* vecenv refactor
* aux buf index fix
* add num aux obs
* reset level with enter
* restore high difficulty flag
* bugfix
* restore train_coinrun.py
* tweaks
* renaming
* renaming
* better arguments handling
* more options
* options cleanup
* game data refactor
* more options
* args for train_procgen
* add close handler to interactive base class
* use debug build if debug=True, fix range on aux_obs
* add ProcGenEnv to __init__.py, add missing imports to procgen.py
* export RemoveDictWrapper and build, update train_procgen.py, move assets download into env creation and replace init_assets_and_build with just build
* fix formatting issues
* only call global init once
* fix path in setup.py
* revert part of makefile
* ignore IDE files and folders
* vec remove dict
* export VecRemoveDictObs
* remove RemoveDictWrapper
* remove IDE files
* move shared .h and .cpp files to common folder, update build to use those, dedupe env.cpp
* fix missing header
* try unified build function
* remove old scripts dir
* add comment on build
* upload libenv with render fixes
* tell qthreads to die when we unload the library
* pyglet.app.run is garbage
* static fixes
* whoops
* actually vsync is on
* cleanup
* cleanup
* extern C for libenv interface
* parse util rcall arg
* high difficulty fix
* game type enums
* ProcGenEnv subclasses
* game type cleanup
* unrecognized key
* unrecognized game type
* parse util reorg
* args management
* typo fix
* GinParser
* arg tweaks
* tweak
* restore start_level/num_levels setting
* fix create_procgen_env interface
* build fix
* procgen args in init signature
* fix
* build fix
* fix logger usage in ppo_metal/run_retro
---
 baselines/common/vec_env/__init__.py | 3 ++-
.../common/vec_env/vec_remove_dict_obs.py | 22 +++++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)
 create mode 100644 baselines/common/vec_env/vec_remove_dict_obs.py

diff --git a/baselines/common/vec_env/__init__.py b/baselines/common/vec_env/__init__.py
index 5155650..ff6a305 100644
--- a/baselines/common/vec_env/__init__.py
+++ b/baselines/common/vec_env/__init__.py
@@ -5,5 +5,6 @@ from .subproc_vec_env import SubprocVecEnv
 from .vec_frame_stack import VecFrameStack
 from .vec_monitor import VecMonitor
 from .vec_normalize import VecNormalize
+from .vec_remove_dict_obs import VecRemoveDictObs
 
-__all__ = ['AlreadySteppingError', 'NotSteppingError', 'VecEnv', 'VecEnvWrapper', 'VecEnvObservationWrapper', 'CloudpickleWrapper', 'DummyVecEnv', 'ShmemVecEnv', 'SubprocVecEnv', 'VecFrameStack', 'VecMonitor', 'VecNormalize']
+__all__ = ['AlreadySteppingError', 'NotSteppingError', 'VecEnv', 'VecEnvWrapper', 'VecEnvObservationWrapper', 'CloudpickleWrapper', 'DummyVecEnv', 'ShmemVecEnv', 'SubprocVecEnv', 'VecFrameStack', 'VecMonitor', 'VecNormalize', 'VecRemoveDictObs']
diff --git a/baselines/common/vec_env/vec_remove_dict_obs.py b/baselines/common/vec_env/vec_remove_dict_obs.py
new file mode 100644
index 0000000..f62d387
--- /dev/null
+++ b/baselines/common/vec_env/vec_remove_dict_obs.py
@@ -0,0 +1,22 @@
+from .vec_env import VecEnvWrapper
+
+
+class VecRemoveDictObs(VecEnvWrapper):
+    """
+    PPO2 doesn't support dictionary observations, so make the environment only expose the observation for the provided key.
+    """
+
+    def __init__(self, venv, key):
+        self._key = key
+        self._venv = venv
+        super().__init__(venv, observation_space=venv.observation_space.spaces[self._key])
+
+    def _remove_dict(self, obs):
+        return obs[self._key]
+
+    def reset(self):
+        return self._remove_dict(self._venv.reset())
+
+    def step_wait(self):
+        obs, rews, dones, infos = self._venv.step_wait()
+        return self._remove_dict(obs), rews, dones, infos
\ No newline at end of file