make_atari_env compatible with mpi
This commit is contained in:
@@ -21,11 +21,12 @@ def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
|
||||
Create a wrapped, monitored SubprocVecEnv for Atari.
|
||||
"""
|
||||
if wrapper_kwargs is None: wrapper_kwargs = {}
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
def make_env(rank): # pylint: disable=C0111
|
||||
def _thunk():
|
||||
env = make_atari(env_id)
|
||||
env.seed(seed + rank if seed is not None else None)
|
||||
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
|
||||
env.seed(seed + 10000*mpi_rank + rank if seed is not None else None)
|
||||
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank)))
|
||||
return wrap_deepmind(env, **wrapper_kwargs)
|
||||
return _thunk
|
||||
set_global_seeds(seed)
|
||||
|
@@ -72,8 +72,7 @@ def build_env(args, render=False):
|
||||
nenv = args.num_env or ncpu if not render else 1
|
||||
alg = args.alg
|
||||
rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
|
||||
seed = args.seed + 10000 * rank if args.seed is not None else None
|
||||
seed = args.seed
|
||||
|
||||
env_type, env_id = get_env_type(args.env)
|
||||
if env_type == 'mujoco':
|
||||
|
Reference in New Issue
Block a user