diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py
index d69589c..44dafa1 100644
--- a/baselines/common/cmd_util.py
+++ b/baselines/common/cmd_util.py
@@ -16,30 +16,56 @@ from baselines.common import set_global_seeds
 from baselines.common.atari_wrappers import make_atari, wrap_deepmind
 from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
 from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
-from baselines.common.retro_wrappers import RewardScaler
+from baselines.common import retro_wrappers
 
-
-def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0):
+def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None):
     """
     Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
     """
     if wrapper_kwargs is None: wrapper_kwargs = {}
     mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
-    def make_env(rank): # pylint: disable=C0111
-        def _thunk():
-            env = make_atari(env_id) if env_type == 'atari' else gym.make(env_id)
-            env.seed(seed + 10000*mpi_rank + rank if seed is not None else None)
-            env = Monitor(env,
-                          logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank)),
-                          allow_early_resets=True)
+    seed = seed + 10000 * mpi_rank if seed is not None else None
+    def make_thunk(rank):
+        return lambda: make_env(
+            env_id=env_id,
+            env_type=env_type,
+            subrank = rank,
+            seed=seed,
+            reward_scale=reward_scale,
+            gamestate=gamestate
+        )
 
-            if env_type == 'atari': return wrap_deepmind(env, **wrapper_kwargs)
-            elif reward_scale != 1: return RewardScaler(env, reward_scale)
-            else: return env
-        return _thunk
     set_global_seeds(seed)
-    if num_env > 1: return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
-    else: return DummyVecEnv([make_env(start_index)])
+    if num_env > 1:
+        return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
+    else:
+        return DummyVecEnv([make_thunk(start_index)])
+
+
+def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs=None):
+    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
+    if env_type == 'atari':
+        env = make_atari(env_id)
+    elif env_type == 'retro':
+        import retro
+        gamestate = gamestate or retro.State.DEFAULT
+        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
+    else:
+        env = gym.make(env_id)
+
+    env.seed(seed + subrank if seed is not None else None)
+    env = Monitor(env,
+                  logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
+                  allow_early_resets=True)
+
+    if env_type == 'atari':
+        return wrap_deepmind(env, **wrapper_kwargs)
+    elif reward_scale != 1:
+        return retro_wrappers.RewardScaler(env, reward_scale)
+    else:
+        return env
+
+
 
 def make_mujoco_env(env_id, seed, reward_scale=1.0):
     """
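The cmd_util.py change above replaces the inline `_thunk` closure with a module-level `make_env` plus a thin `make_thunk` wrapper. Below is a minimal usage sketch of the two helpers, assuming this branch of baselines is installed together with gym's Atari extras; the environment ids, seed, and worker count are illustrative values, not part of the patch. Since `make_env` leaves `wrapper_kwargs` as `None` by default, the Atari call passes an explicit dict for `wrap_deepmind`.

```python
# Illustrative only: env ids, seed, and num_env are example values, not from the patch.
from baselines.common.cmd_util import make_vec_env, make_env

# Vectorized path: make_vec_env offsets the seed once by 10000 * mpi_rank, and each
# worker thunk calls make_env(..., subrank=rank) so every sub-environment is seeded
# differently. A non-Atari env_type falls through to plain gym.make inside make_env.
venv = make_vec_env('CartPole-v0', 'classic_control', num_env=4, seed=0,
                    reward_scale=1.0, gamestate=None)

# Single-env path, mirroring what run.py now does for deepq: wrapper_kwargs is
# forwarded to wrap_deepmind, so pass an explicit dict for Atari environments.
env = make_env('PongNoFrameskip-v4', 'atari', seed=0,
               wrapper_kwargs={'frame_stack': True})
```

Hoisting environment construction into a module-level `make_env` is what lets run.py reuse the same Monitor and wrapper setup for the single-env deepq and trpo_mpi branches, as the run.py hunks below show.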
diff --git a/baselines/run.py b/baselines/run.py
index 8ab71ac..dedca8b 100644
--- a/baselines/run.py
+++ b/baselines/run.py
@@ -7,13 +7,12 @@ import tensorflow as tf
 import numpy as np
 
 from baselines.common.vec_env.vec_frame_stack import VecFrameStack
-from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
+from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
 from baselines.common.tf_util import get_session
-from baselines import bench, logger
+from baselines import logger
 from importlib import import_module
 
 from baselines.common.vec_env.vec_normalize import VecNormalize
-from baselines.common import atari_wrappers, retro_wrappers
 
 try:
     from mpi4py import MPI
@@ -87,38 +86,21 @@ def build_env(args):
     if sys.platform == 'darwin': ncpu //= 2
     nenv = args.num_env or ncpu
     alg = args.alg
-    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
     seed = args.seed
 
     env_type, env_id = get_env_type(args.env)
-    if env_type == 'atari':
+    if env_type in {'atari', 'retro'}:
         if alg == 'acer':
             env = make_vec_env(env_id, env_type, nenv, seed)
         elif alg == 'deepq':
-            env = atari_wrappers.make_atari(env_id)
-            env.seed(seed)
-            env = bench.Monitor(env, logger.get_dir())
-            env = atari_wrappers.wrap_deepmind(env, frame_stack=True)
+            env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
         elif alg == 'trpo_mpi':
-            env = atari_wrappers.make_atari(env_id)
-            env.seed(seed)
-            env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
-            env = atari_wrappers.wrap_deepmind(env)
-            # TODO check if the second seeding is necessary, and eventually remove
-            env.seed(seed)
+            env = make_env(env_id, env_type, seed=seed)
         else:
             frame_stack_size = 4
-            env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)
-
-    elif env_type == 'retro':
-        import retro
-        gamestate = args.gamestate or retro.State.DEFAULT
-        env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
-                                        use_restricted_actions=retro.Actions.DISCRETE)
-        env.seed(args.seed)
-        env = bench.Monitor(env, logger.get_dir())
-        env = retro_wrappers.wrap_deepmind_retro(env)
+            env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
+            env = VecFrameStack(env, frame_stack_size)
 
     else:
         config = tf.ConfigProto(allow_soft_placement=True,