Merge pull request #777 from openai/aray-extra-imports
add an argument for importing extra modules from run
This commit is contained in:
@@ -145,6 +145,7 @@ def common_arg_parser():
|
|||||||
parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
|
parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
|
||||||
parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
|
parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
|
||||||
parser.add_argument('--play', default=False, action='store_true')
|
parser.add_argument('--play', default=False, action='store_true')
|
||||||
|
parser.add_argument('--extra_import', help='Extra module to import to access external environments', type=str, default=None)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def robotics_arg_parser():
|
def robotics_arg_parser():
|
||||||
|
@@ -120,6 +120,11 @@ def build_env(args):
|
|||||||
|
|
||||||
|
|
||||||
def get_env_type(env_id):
|
def get_env_type(env_id):
|
||||||
|
# Re-parse the gym registry, since we could have new envs since last time.
|
||||||
|
for env in gym.envs.registry.all():
|
||||||
|
env_type = env._entry_point.split(':')[0].split('.')[-1]
|
||||||
|
_game_envs[env_type].add(env.id) # This is a set so add is idempotent
|
||||||
|
|
||||||
if env_id in _game_envs.keys():
|
if env_id in _game_envs.keys():
|
||||||
env_type = env_id
|
env_type = env_id
|
||||||
env_id = [g for g in _game_envs[env_type]][0]
|
env_id = [g for g in _game_envs[env_type]][0]
|
||||||
@@ -189,6 +194,9 @@ def main(args):
|
|||||||
args, unknown_args = arg_parser.parse_known_args(args)
|
args, unknown_args = arg_parser.parse_known_args(args)
|
||||||
extra_args = parse_cmdline_kwargs(unknown_args)
|
extra_args = parse_cmdline_kwargs(unknown_args)
|
||||||
|
|
||||||
|
if args.extra_import is not None:
|
||||||
|
import_module(args.extra_import)
|
||||||
|
|
||||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||||
rank = 0
|
rank = 0
|
||||||
logger.configure()
|
logger.configure()
|
||||||
|
Reference in New Issue
Block a user