Update cmd_util with initializer, env_kwargs, and force_dummy
This commit is contained in:
@@ -23,34 +23,44 @@ def make_vec_env(env_id, env_type, num_env, seed,
|
||||
start_index=0,
|
||||
reward_scale=1.0,
|
||||
flatten_dict_observations=True,
|
||||
gamestate=None):
|
||||
gamestate=None,
|
||||
initializer=None,
|
||||
env_kwargs=None,
|
||||
force_dummy=False):
|
||||
"""
|
||||
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
|
||||
"""
|
||||
wrapper_kwargs = wrapper_kwargs or {}
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
seed = seed + 10000 * mpi_rank if seed is not None else None
|
||||
def make_thunk(rank):
|
||||
logger_dir = logger.get_dir()
|
||||
def make_thunk(rank, initializer=None):
|
||||
return lambda: make_env(
|
||||
env_id=env_id,
|
||||
env_type=env_type,
|
||||
subrank = rank,
|
||||
mpi_rank=mpi_rank,
|
||||
subrank=rank,
|
||||
seed=seed,
|
||||
reward_scale=reward_scale,
|
||||
gamestate=gamestate,
|
||||
flatten_dict_observations=flatten_dict_observations,
|
||||
wrapper_kwargs=wrapper_kwargs
|
||||
wrapper_kwargs=wrapper_kwargs,
|
||||
logger_dir=logger_dir,
|
||||
initializer=initializer,
|
||||
env_kwargs=env_kwargs,
|
||||
)
|
||||
|
||||
set_global_seeds(seed)
|
||||
if num_env > 1:
|
||||
return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
|
||||
if not force_dummy and num_env > 1:
|
||||
return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)])
|
||||
else:
|
||||
return DummyVecEnv([make_thunk(start_index)])
|
||||
return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)])
|
||||
|
||||
|
||||
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None):
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None, initializer=None, env_kwargs=None):
|
||||
if initializer is not None:
|
||||
initializer(mpi_rank=mpi_rank, subrank=subrank)
|
||||
|
||||
wrapper_kwargs = wrapper_kwargs or {}
|
||||
if env_type == 'atari':
|
||||
env = make_atari(env_id)
|
||||
@@ -59,7 +69,7 @@ def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate
|
||||
gamestate = gamestate or retro.State.DEFAULT
|
||||
env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
|
||||
else:
|
||||
env = gym.make(env_id)
|
||||
env = gym.make(env_id, **(env_kwargs or {}))
|
||||
|
||||
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
|
||||
keys = env.observation_space.spaces.keys()
|
||||
@@ -67,7 +77,7 @@ def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate
|
||||
|
||||
env.seed(seed + subrank if seed is not None else None)
|
||||
env = Monitor(env,
|
||||
logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
|
||||
logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
|
||||
allow_early_resets=True)
|
||||
|
||||
if env_type == 'atari':
|
||||
@@ -134,6 +144,7 @@ def common_arg_parser():
|
||||
"""
|
||||
parser = arg_parser()
|
||||
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
|
||||
parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
|
||||
parser.add_argument('--seed', help='RNG seed', type=int, default=None)
|
||||
parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
|
||||
parser.add_argument('--num_timesteps', type=float, default=1e6),
|
||||
|
Reference in New Issue
Block a user