From 6bbc4635e65f5c70cf2eaa5b8035ca9a85b734dd Mon Sep 17 00:00:00 2001 From: Greg Brockman Date: Mon, 18 Mar 2019 17:42:10 -0700 Subject: [PATCH] Update cmd_util with initializer, env_kwargs, and force_dummy --- baselines/common/cmd_util.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index 24b5b90..76a5044 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -23,34 +23,44 @@ def make_vec_env(env_id, env_type, num_env, seed, start_index=0, reward_scale=1.0, flatten_dict_observations=True, - gamestate=None): + gamestate=None, + initializer=None, + env_kwargs=None, + force_dummy=False): """ Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo. """ wrapper_kwargs = wrapper_kwargs or {} mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 seed = seed + 10000 * mpi_rank if seed is not None else None - def make_thunk(rank): + logger_dir = logger.get_dir() + def make_thunk(rank, initializer=None): return lambda: make_env( env_id=env_id, env_type=env_type, - subrank = rank, + mpi_rank=mpi_rank, + subrank=rank, seed=seed, reward_scale=reward_scale, gamestate=gamestate, flatten_dict_observations=flatten_dict_observations, - wrapper_kwargs=wrapper_kwargs + wrapper_kwargs=wrapper_kwargs, + logger_dir=logger_dir, + initializer=initializer, + env_kwargs=env_kwargs, ) set_global_seeds(seed) - if num_env > 1: - return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)]) + if not force_dummy and num_env > 1: + return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)]) else: - return DummyVecEnv([make_thunk(start_index)]) + return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)]) -def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None): - mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 +def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None, initializer=None, env_kwargs=None): + if initializer is not None: + initializer(mpi_rank=mpi_rank, subrank=subrank) + wrapper_kwargs = wrapper_kwargs or {} if env_type == 'atari': env = make_atari(env_id) @@ -59,7 +69,7 @@ def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate gamestate = gamestate or retro.State.DEFAULT env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate) else: - env = gym.make(env_id) + env = gym.make(env_id, **(env_kwargs or {})) if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict): keys = env.observation_space.spaces.keys() @@ -67,7 +77,7 @@ def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate env.seed(seed + subrank if seed is not None else None) env = Monitor(env, - logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)), + logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)), allow_early_resets=True) if env_type == 'atari': @@ -134,6 +144,7 @@ def common_arg_parser(): """ parser = arg_parser() parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2') + parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str) parser.add_argument('--seed', help='RNG seed', type=int, default=None) parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2') parser.add_argument('--num_timesteps', type=float, default=1e6),