parse colon-separated env_id's
This commit is contained in:
@@ -150,7 +150,6 @@ def common_arg_parser():
|
|||||||
parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
|
parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
|
||||||
parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
|
parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
|
||||||
parser.add_argument('--play', default=False, action='store_true')
|
parser.add_argument('--play', default=False, action='store_true')
|
||||||
parser.add_argument('--extra_import', help='Extra module to import to access external environments', type=str, default=None)
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def robotics_arg_parser():
|
def robotics_arg_parser():
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import sys
|
import sys
|
||||||
|
import re
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import gym
|
import gym
|
||||||
@@ -137,6 +138,8 @@ def get_env_type(args):
|
|||||||
if env_id in e:
|
if env_id in e:
|
||||||
env_type = g
|
env_type = g
|
||||||
break
|
break
|
||||||
|
if ':' in env_id:
|
||||||
|
env_type = re.sub(r':.*', '', env_id)
|
||||||
assert env_type is not None, 'env_id {} is not recognized in env types'.format(env_id, _game_envs.keys())
|
assert env_type is not None, 'env_id {} is not recognized in env types'.format(env_id, _game_envs.keys())
|
||||||
|
|
||||||
return env_type, env_id
|
return env_type, env_id
|
||||||
@@ -197,9 +200,6 @@ def main(args):
|
|||||||
args, unknown_args = arg_parser.parse_known_args(args)
|
args, unknown_args = arg_parser.parse_known_args(args)
|
||||||
extra_args = parse_cmdline_kwargs(unknown_args)
|
extra_args = parse_cmdline_kwargs(unknown_args)
|
||||||
|
|
||||||
if args.extra_import is not None:
|
|
||||||
import_module(args.extra_import)
|
|
||||||
|
|
||||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||||
rank = 0
|
rank = 0
|
||||||
logger.configure()
|
logger.configure()
|
||||||
|
Reference in New Issue
Block a user