From 4d0746b9579ebf61401276921fb84dc4268abf9e Mon Sep 17 00:00:00 2001 From: Alex Ray Date: Thu, 3 Jan 2019 11:33:31 -0800 Subject: [PATCH] add an argument for importing extra modules from run --- baselines/common/cmd_util.py | 1 + baselines/run.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index 650911e..24b5b90 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -145,6 +145,7 @@ def common_arg_parser(): parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int) parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int) parser.add_argument('--play', default=False, action='store_true') + parser.add_argument('--extra_import', help='Extra module to import to access external environments', type=str, default=None) return parser def robotics_arg_parser(): diff --git a/baselines/run.py b/baselines/run.py index a493071..cc63c11 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -120,6 +120,11 @@ def build_env(args): def get_env_type(env_id): + # Re-parse the gym registry, since we could have new envs since last time. + for env in gym.envs.registry.all(): + env_type = env._entry_point.split(':')[0].split('.')[-1] + _game_envs[env_type].add(env.id) # This is a set so add is idempotent + if env_id in _game_envs.keys(): env_type = env_id env_id = [g for g in _game_envs[env_type]][0] @@ -189,6 +194,9 @@ def main(args): args, unknown_args = arg_parser.parse_known_args(args) extra_args = parse_cmdline_kwargs(unknown_args) + if args.extra_import is not None: + import_module(args.extra_import) + if MPI is None or MPI.COMM_WORLD.Get_rank() == 0: rank = 0 logger.configure()