diff --git a/baselines/run.py b/baselines/run.py
index f170225..2aceb69 100644
--- a/baselines/run.py
+++ b/baselines/run.py
@@ -1,5 +1,5 @@
 import sys
-import multiprocessing
+import multiprocessing
 import os.path as osp
 import gym
 from collections import defaultdict
@@ -30,7 +30,7 @@ for env in gym.envs.registry.all():
 # reading benchmark names directly from retro requires
 # importing retro here, and for some reason that crashes tensorflow
 # in ubuntu
-_game_envs['retro'] = set([
+_game_envs['retro'] = {
     'BubbleBobble-Nes',
     'SuperMarioBros-Nes',
     'TwinBee3PokoPokoDaimaou-Nes',
@@ -39,7 +39,7 @@ _game_envs['retro'] = set([
     'Vectorman-Genesis',
     'FinalFight-Snes',
     'SpaceInvaders-Snes',
-])
+}
 
 
 def train(args, extra_args):
@@ -60,8 +60,6 @@ def train(args, extra_args):
 
     if alg_kwargs.get('network') is None:
         alg_kwargs['network'] = get_default_network(env_type)
-
-
     print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
 
     model = learn(
@@ -91,7 +89,7 @@ def build_env(args):
         if args.num_env:
             env = make_vec_env(env_id, env_type, nenv, seed, reward_scale=args.reward_scale)
         else:
-            env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
+            env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
 
         env = VecNormalize(env)
 
@@ -117,7 +115,8 @@ def build_env(args):
     elif env_type == 'retro':
         import retro
         gamestate = args.gamestate or 'Level1-1'
-        env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE)
+        env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
+                                        use_restricted_actions=retro.Actions.DISCRETE)
         env.seed(args.seed)
         env = bench.Monitor(env, logger.get_dir())
         env = retro_wrappers.wrap_deepmind_retro(env)
@@ -140,7 +139,7 @@ def build_env(args):
 def get_env_type(env_id):
     if env_id in _game_envs.keys():
         env_type = env_id
-        env_id = [g for g in _game_envs[env_type]][0]
+        env_id = [g for g in _game_envs[env_type]][0]
     else:
         env_type = None
         for g, e in _game_envs.items():
@@ -151,6 +150,7 @@ def get_env_type(env_id):
 
     return env_type, env_id
 
+
 def get_default_network(env_type):
     if env_type == 'mujoco' or env_type == 'classic_control':
         return 'mlp'
@@ -159,6 +159,7 @@
 
     raise ValueError('Unknown env_type {}'.format(env_type))
 
+
 def get_alg_module(alg, submodule=None):
     submodule = submodule or alg
     try:
@@ -174,6 +175,7 @@
 def get_learn_function(alg):
     return get_alg_module(alg).learn
 
+
 def get_learn_function_defaults(alg, env_type):
     try:
         alg_defaults = get_alg_module(alg, 'defaults')
@@ -182,6 +184,7 @@
         kwargs = {}
 
+
 def parse(v):
     '''
     convert value of a command-line arg to a python object if possible, othewise, keep as string
     '''
@@ -199,14 +202,13 @@ def main():
     arg_parser = common_arg_parser()
     args, unknown_args = arg_parser.parse_known_args()
 
-    extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
-
+    extra_args = {k: parse(v) for k, v in parse_unknown_args(unknown_args).items()}
 
     if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
         rank = 0
         logger.configure()
     else:
-        logger.configure(format_strs = [])
+        logger.configure(format_strs=[])
         rank = MPI.COMM_WORLD.Get_rank()
 
     model, _ = train(args, extra_args)
@@ -215,14 +217,13 @@ def main():
         save_path = osp.expanduser(args.save_path)
         model.save(save_path)
 
-
     if args.play:
         logger.log("Running trained model")
         env = build_env(args)
         obs = env.reset()
         while True:
             actions = model.step(obs)[0]
-            obs, _, done, _ = env.step(actions)
+            obs, _, done, _ = env.step(actions)
             env.render()
             done = done.any() if isinstance(done, np.ndarray) else done
 
@@ -230,6 +231,5 @@
                 obs = env.reset()
 
 
-
 if __name__ == '__main__':
     main()
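
Aside from whitespace and line-wrapping cleanups, the only structural change in this patch is replacing the `set([...])` call with a set literal when building `_game_envs['retro']`. A minimal standalone sketch (not part of the patch, with the entry list shortened for brevity) showing that the two spellings build the same set; the literal simply skips the intermediate list:

```python
# Building the retro game set two ways: set() over a list vs. a set literal.
retro_envs_via_set_call = set([
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
])
retro_envs_via_literal = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
}
# Same contents either way; the literal avoids constructing a temporary list.
assert retro_envs_via_set_call == retro_envs_via_literal
```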