From fcd84aa831f8a5632b6dac85a93c3a7572f58b38 Mon Sep 17 00:00:00 2001 From: Peter Zhokhov Date: Wed, 1 Aug 2018 14:46:18 -0700 Subject: [PATCH] make_atari_env compatible with mpi --- baselines/common/cmd_util.py | 5 +++-- baselines/run.py | 3 +-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index e3c770a..681a80c 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -21,11 +21,12 @@ def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0): Create a wrapped, monitored SubprocVecEnv for Atari. """ if wrapper_kwargs is None: wrapper_kwargs = {} + mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 def make_env(rank): # pylint: disable=C0111 def _thunk(): env = make_atari(env_id) - env.seed(seed + rank if seed is not None else None) - env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank))) + env.seed(seed + 10000*mpi_rank + rank if seed is not None else None) + env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank))) return wrap_deepmind(env, **wrapper_kwargs) return _thunk set_global_seeds(seed) diff --git a/baselines/run.py b/baselines/run.py index 9879d1c..6423e3a 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -72,8 +72,7 @@ def build_env(args, render=False): nenv = args.num_env or ncpu if not render else 1 alg = args.alg rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 - - seed = args.seed + 10000 * rank if args.seed is not None else None + seed = args.seed env_type, env_id = get_env_type(args.env) if env_type == 'mujoco':