make_atari_env compatible with mpi

This commit is contained in:
Peter Zhokhov
2018-08-01 14:46:18 -07:00
parent 3528f7b992
commit fcd84aa831
2 changed files with 4 additions and 4 deletions

View File

@@ -21,11 +21,12 @@ def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
Create a wrapped, monitored SubprocVecEnv for Atari.
"""
if wrapper_kwargs is None: wrapper_kwargs = {}
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
def make_env(rank): # pylint: disable=C0111
def _thunk():
env = make_atari(env_id)
env.seed(seed + rank if seed is not None else None)
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
env.seed(seed + 10000*mpi_rank + rank if seed is not None else None)
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank)))
return wrap_deepmind(env, **wrapper_kwargs)
return _thunk
set_global_seeds(seed)

View File

@@ -72,8 +72,7 @@ def build_env(args, render=False):
nenv = args.num_env or ncpu if not render else 1
alg = args.alg
rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
seed = args.seed + 10000 * rank if args.seed is not None else None
seed = args.seed
env_type, env_id = get_env_type(args.env)
if env_type == 'mujoco':