add an argument for importing extra modules from run
This commit is contained in:
@@ -120,6 +120,11 @@ def build_env(args):
|
||||
|
||||
|
||||
def get_env_type(env_id):
|
||||
# Re-parse the gym registry, since we could have new envs since last time.
|
||||
for env in gym.envs.registry.all():
|
||||
env_type = env._entry_point.split(':')[0].split('.')[-1]
|
||||
_game_envs[env_type].add(env.id) # This is a set so add is idempotent
|
||||
|
||||
if env_id in _game_envs.keys():
|
||||
env_type = env_id
|
||||
env_id = [g for g in _game_envs[env_type]][0]
|
||||
@@ -189,6 +194,9 @@ def main(args):
|
||||
args, unknown_args = arg_parser.parse_known_args(args)
|
||||
extra_args = parse_cmdline_kwargs(unknown_args)
|
||||
|
||||
if args.extra_import is not None:
|
||||
import_module(args.extra_import)
|
||||
|
||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||
rank = 0
|
||||
logger.configure()
|
||||
|
Reference in New Issue
Block a user