Vecenv refactor (#223)
* update karl util * restore pvi flag * change rcall auto cpu behavior, move gin.configurable, add os.makedirs * vecenv refactor * aux buf index fix * add num aux obs * reset level with enter * restore high difficulty flag * bugfix * restore train_coinrun.py * tweaks * renaming * renaming * better arguments handling * more options * options cleanup * game data refactor * more options * args for train_procgen * add close handler to interactive base class * use debug build if debug=True, fix range on aux_obs * add ProcGenEnv to __init__.py, add missing imports to procgen.py * export RemoveDictWrapper and build, update train_procgen.py, move assets download into env creation and replace init_assets_and_build with just build * fix formatting issues * only call global init once * fix path in setup.py * revert part of makefile * ignore IDE files and folders * vec remove dict * export VecRemoveDictObs * remove RemoveDictWrapper * remove IDE files * move shared .h and .cpp files to common folder, update build to use those, dedupe env.cpp * fix missing header * try unified build function * remove old scripts dir * add comment on build * upload libenv with render fixes * tell qthreads to die when we unload the library * pyglet.app.run is garbage * static fixes * whoops * actually vsync is on * cleanup * cleanup * extern C for libenv interface * parse util rcall arg * high difficulty fix * game type enums * ProcGenEnv subclasses * game type cleanup * unrecognized key * unrecognized game type * parse util reorg * args management * typo fix * GinParser * arg tweaks * tweak * restore start_level/num_levels setting * fix create_procgen_env interface * build fix * procgen args in init signature * fix * build fix * fix logger usage in ppo_metal/run_retro
This commit is contained in:
committed by
Peter Zhokhov
parent
d760c363bc
commit
1d56af90d3
@@ -5,5 +5,6 @@ from .subproc_vec_env import SubprocVecEnv
|
||||
from .vec_frame_stack import VecFrameStack
|
||||
from .vec_monitor import VecMonitor
|
||||
from .vec_normalize import VecNormalize
|
||||
from .vec_remove_dict_obs import VecRemoveDictObs
|
||||
|
||||
__all__ = ['AlreadySteppingError', 'NotSteppingError', 'VecEnv', 'VecEnvWrapper', 'VecEnvObservationWrapper', 'CloudpickleWrapper', 'DummyVecEnv', 'ShmemVecEnv', 'SubprocVecEnv', 'VecFrameStack', 'VecMonitor', 'VecNormalize']
|
||||
__all__ = ['AlreadySteppingError', 'NotSteppingError', 'VecEnv', 'VecEnvWrapper', 'VecEnvObservationWrapper', 'CloudpickleWrapper', 'DummyVecEnv', 'ShmemVecEnv', 'SubprocVecEnv', 'VecFrameStack', 'VecMonitor', 'VecNormalize', 'VecRemoveDictObs']
|
||||
|
22
baselines/common/vec_env/vec_remove_dict_obs.py
Normal file
22
baselines/common/vec_env/vec_remove_dict_obs.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from .vec_env import VecEnvWrapper
|
||||
|
||||
|
||||
class VecRemoveDictObs(VecEnvWrapper):
|
||||
"""
|
||||
PPO2 doesn't support dictionary observations, so make the environment only expose the observation for the provided key.
|
||||
"""
|
||||
|
||||
def __init__(self, venv, key):
|
||||
self._key = key
|
||||
self._venv = venv
|
||||
super().__init__(venv, observation_space=venv.observation_space.spaces[self._key])
|
||||
|
||||
def _remove_dict(self, obs):
|
||||
return obs[self._key]
|
||||
|
||||
def reset(self):
|
||||
return self._remove_dict(self._venv.reset())
|
||||
|
||||
def step_wait(self):
|
||||
obs, rews, dones, infos = self._venv.step_wait()
|
||||
return self._remove_dict(obs), rews, dones, infos
|
Reference in New Issue
Block a user