From 64dfabb8eb53664d300c573dbd456fbd85ce4aad Mon Sep 17 00:00:00 2001 From: Greg Brockman Date: Tue, 23 Apr 2019 13:40:08 -0700 Subject: [PATCH] Add initializer for process-level setup in SubprocVecEnv (#276) * Add initializer for process-level setup in SubprocVecEnv Use case: run logger.configure() in each subprocess * Add option to force dummy vec env --- baselines/common/cmd_util.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index 016df93..99ec11c 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -25,7 +25,9 @@ def make_vec_env(env_id, env_type, num_env, seed, start_index=0, reward_scale=1.0, flatten_dict_observations=True, - gamestate=None): + gamestate=None, + initializer=None, + force_dummy=False): """ Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo. """ @@ -34,7 +36,7 @@ def make_vec_env(env_id, env_type, num_env, seed, mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 seed = seed + 10000 * mpi_rank if seed is not None else None logger_dir = logger.get_dir() - def make_thunk(rank): + def make_thunk(rank, initializer=None): return lambda: make_env( env_id=env_id, env_type=env_type, @@ -46,17 +48,21 @@ def make_vec_env(env_id, env_type, num_env, seed, flatten_dict_observations=flatten_dict_observations, wrapper_kwargs=wrapper_kwargs, env_kwargs=env_kwargs, - logger_dir=logger_dir + logger_dir=logger_dir, + initializer=initializer ) set_global_seeds(seed) - if num_env > 1: - return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)]) + if not force_dummy and num_env > 1: + return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)]) else: - return DummyVecEnv([make_thunk(start_index)]) + return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)]) -def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None): +def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None): + if initializer is not None: + initializer(mpi_rank=mpi_rank, subrank=subrank) + wrapper_kwargs = wrapper_kwargs or {} env_kwargs = env_kwargs or {} if ':' in env_id: