diff --git a/baselines/run.py b/baselines/run.py index 2210b51..faa2786 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -21,9 +21,19 @@ try: except ImportError: MPI = None +try: + import pybullet_envs +except ImportError: + pybullet_envs = None + +try: + import roboschool +except ImportError: + roboschool = None + _game_envs = defaultdict(set) for env in gym.envs.registry.all(): - # solve this with regexes + # TODO: solve this with regexes env_type = env._entry_point.split(':')[0].split('.')[-1] _game_envs[env_type].add(env.id) @@ -44,6 +54,7 @@ _game_envs['retro'] = { def train(args, extra_args): env_type, env_id = get_env_type(args.env) + print(f'env_type: {env_type}') total_timesteps = int(args.num_timesteps) seed = args.seed