From adaa8aefa828c8b426e9ecf6e15f5990b3c1abef Mon Sep 17 00:00:00 2001 From: pzhokhov Date: Mon, 10 Sep 2018 11:50:59 -0700 Subject: [PATCH] baselines issue #564 (#574) * fixes to enjoy_cartpole, enjoy_mountaincar.py * fixed {train,enjoy}_pong, removed enjoy_retro * set number of timesteps to 1e7 in train_pong * flake8 complaints * use synchronous version fo acktr in test_env_after_learn * flake8 --- baselines/common/tf_util.py | 8 ++- baselines/deepq/deepq.py | 17 +++--- baselines/deepq/experiments/enjoy_cartpole.py | 2 +- .../deepq/experiments/enjoy_mountaincar.py | 8 ++- baselines/deepq/experiments/enjoy_pong.py | 11 +++- baselines/deepq/experiments/enjoy_retro.py | 34 ------------ baselines/deepq/experiments/run_atari.py | 52 ------------------- baselines/deepq/experiments/run_retro.py | 49 ----------------- .../deepq/experiments/train_mountaincar.py | 6 +-- baselines/deepq/experiments/train_pong.py | 36 +++++++++++++ 10 files changed, 70 insertions(+), 153 deletions(-) delete mode 100644 baselines/deepq/experiments/enjoy_retro.py delete mode 100644 baselines/deepq/experiments/run_atari.py delete mode 100644 baselines/deepq/experiments/run_retro.py create mode 100644 baselines/deepq/experiments/train_pong.py diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py index 92dde9a..a40b109 100644 --- a/baselines/common/tf_util.py +++ b/baselines/common/tf_util.py @@ -322,7 +322,9 @@ def save_state(fname, sess=None): from baselines import logger logger.warn('save_state method is deprecated, please use save_variables instead') sess = sess or get_session() - os.makedirs(os.path.dirname(fname), exist_ok=True) + dirname = os.path.dirname(fname) + if any(dirname): + os.makedirs(dirname, exist_ok=True) saver = tf.train.Saver() saver.save(tf.get_default_session(), fname) @@ -335,7 +337,9 @@ def save_variables(save_path, variables=None, sess=None): ps = sess.run(variables) save_dict = {v.name: value for v, value in zip(variables, ps)} - os.makedirs(os.path.dirname(save_path), exist_ok=True) + dirname = os.path.dirname(save_path) + if any(dirname): + os.makedirs(dirname, exist_ok=True) joblib.dump(save_dict, save_path) def load_variables(load_path, variables=None, sess=None): diff --git a/baselines/deepq/deepq.py b/baselines/deepq/deepq.py index 01921bb..5a4b2e7 100644 --- a/baselines/deepq/deepq.py +++ b/baselines/deepq/deepq.py @@ -7,7 +7,7 @@ import cloudpickle import numpy as np import baselines.common.tf_util as U -from baselines.common.tf_util import load_state, save_state +from baselines.common.tf_util import load_variables, save_variables from baselines import logger from baselines.common.schedules import LinearSchedule from baselines.common import set_global_seeds @@ -39,7 +39,7 @@ class ActWrapper(object): f.write(model_data) zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td) - load_state(os.path.join(td, "model")) + load_variables(os.path.join(td, "model")) return ActWrapper(act, act_params) @@ -55,7 +55,7 @@ class ActWrapper(object): path = os.path.join(logger.get_dir(), "model.pkl") with tempfile.TemporaryDirectory() as td: - save_state(os.path.join(td, "model")) + save_variables(os.path.join(td, "model")) arc_name = os.path.join(td, "packed.zip") with zipfile.ZipFile(arc_name, 'w') as zipf: for root, dirs, files in os.walk(td): @@ -69,8 +69,7 @@ class ActWrapper(object): cloudpickle.dump((model_data, self._act_params), f) def save(self, path): - save_state(path) - self.save_act(path+".pickle") + save_variables(path) def load_act(path): @@ -249,11 +248,11 
@@ def learn(env, model_saved = False if tf.train.latest_checkpoint(td) is not None: - load_state(model_file) + load_variables(model_file) logger.log('Loaded model from {}'.format(model_file)) model_saved = True elif load_path is not None: - load_state(load_path) + load_variables(load_path) logger.log('Loaded model from {}'.format(load_path)) @@ -322,12 +321,12 @@ def learn(env, if print_freq is not None: logger.log("Saving model due to mean reward increase: {} -> {}".format( saved_mean_reward, mean_100ep_reward)) - save_state(model_file) + save_variables(model_file) model_saved = True saved_mean_reward = mean_100ep_reward if model_saved: if print_freq is not None: logger.log("Restored model with mean reward: {}".format(saved_mean_reward)) - load_state(model_file) + load_variables(model_file) return act diff --git a/baselines/deepq/experiments/enjoy_cartpole.py b/baselines/deepq/experiments/enjoy_cartpole.py index 1c6176b..b7d5ef1 100644 --- a/baselines/deepq/experiments/enjoy_cartpole.py +++ b/baselines/deepq/experiments/enjoy_cartpole.py @@ -5,7 +5,7 @@ from baselines import deepq def main(): env = gym.make("CartPole-v0") - act = deepq.load("cartpole_model.pkl") + act = deepq.learn(env, network='mlp', total_timesteps=0, load_path="cartpole_model.pkl") while True: obs, done = env.reset(), False diff --git a/baselines/deepq/experiments/enjoy_mountaincar.py b/baselines/deepq/experiments/enjoy_mountaincar.py index 8bced8c..8b1089e 100644 --- a/baselines/deepq/experiments/enjoy_mountaincar.py +++ b/baselines/deepq/experiments/enjoy_mountaincar.py @@ -1,11 +1,17 @@ import gym from baselines import deepq +from baselines.common import models def main(): env = gym.make("MountainCar-v0") - act = deepq.load("mountaincar_model.pkl") + act = deepq.learn( + env, + network=models.mlp(num_layers=1, num_hidden=64), + total_timesteps=0, + load_path='mountaincar_model.pkl' + ) while True: obs, done = env.reset(), False diff --git a/baselines/deepq/experiments/enjoy_pong.py b/baselines/deepq/experiments/enjoy_pong.py index 5b16fec..0b118c7 100644 --- a/baselines/deepq/experiments/enjoy_pong.py +++ b/baselines/deepq/experiments/enjoy_pong.py @@ -5,14 +5,21 @@ from baselines import deepq def main(): env = gym.make("PongNoFrameskip-v4") env = deepq.wrap_atari_dqn(env) - act = deepq.load("pong_model.pkl") + model = deepq.learn( + env, + "conv_only", + convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], + hiddens=[256], + dueling=True, + total_timesteps=0 + ) while True: obs, done = env.reset(), False episode_rew = 0 while not done: env.render() - obs, rew, done, _ = env.step(act(obs[None])[0]) + obs, rew, done, _ = env.step(model(obs[None])[0]) episode_rew += rew print("Episode reward", episode_rew) diff --git a/baselines/deepq/experiments/enjoy_retro.py b/baselines/deepq/experiments/enjoy_retro.py deleted file mode 100644 index 526af16..0000000 --- a/baselines/deepq/experiments/enjoy_retro.py +++ /dev/null @@ -1,34 +0,0 @@ -import argparse - -import numpy as np - -from baselines import deepq -from baselines.common import retro_wrappers - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--env', help='environment ID', default='SuperMarioBros-Nes') - parser.add_argument('--gamestate', help='game state to load', default='Level1-1') - parser.add_argument('--model', help='model pickle file from ActWrapper.save', default='model.pkl') - args = parser.parse_args() - - env = retro_wrappers.make_retro(game=args.env, state=args.gamestate, max_episode_steps=None) - env = 
retro_wrappers.wrap_deepmind_retro(env) - act = deepq.load(args.model) - - while True: - obs, done = env.reset(), False - episode_rew = 0 - while not done: - env.render() - action = act(obs[None])[0] - env_action = np.zeros(env.action_space.n) - env_action[action] = 1 - obs, rew, done, _ = env.step(env_action) - episode_rew += rew - print('Episode reward', episode_rew) - - -if __name__ == '__main__': - main() diff --git a/baselines/deepq/experiments/run_atari.py b/baselines/deepq/experiments/run_atari.py deleted file mode 100644 index aa60001..0000000 --- a/baselines/deepq/experiments/run_atari.py +++ /dev/null @@ -1,52 +0,0 @@ -from baselines import deepq -from baselines.common import set_global_seeds -from baselines import bench -import argparse -from baselines import logger -from baselines.common.atari_wrappers import make_atari - - -def main(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') - parser.add_argument('--seed', help='RNG seed', type=int, default=0) - parser.add_argument('--prioritized', type=int, default=1) - parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6) - parser.add_argument('--dueling', type=int, default=1) - parser.add_argument('--num-timesteps', type=int, default=int(10e6)) - parser.add_argument('--checkpoint-freq', type=int, default=10000) - parser.add_argument('--checkpoint-path', type=str, default=None) - - args = parser.parse_args() - logger.configure() - set_global_seeds(args.seed) - env = make_atari(args.env) - env = bench.Monitor(env, logger.get_dir()) - env = deepq.wrap_atari_dqn(env) - - deepq.learn( - env, - "conv_only", - convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], - hiddens=[256], - dueling=bool(args.dueling), - lr=1e-4, - total_timesteps=args.num_timesteps, - buffer_size=10000, - exploration_fraction=0.1, - exploration_final_eps=0.01, - train_freq=4, - learning_starts=10000, - target_network_update_freq=1000, - gamma=0.99, - prioritized_replay=bool(args.prioritized), - prioritized_replay_alpha=args.prioritized_replay_alpha, - checkpoint_freq=args.checkpoint_freq, - checkpoint_path=args.checkpoint_path, - ) - - env.close() - - -if __name__ == '__main__': - main() diff --git a/baselines/deepq/experiments/run_retro.py b/baselines/deepq/experiments/run_retro.py deleted file mode 100644 index 0338361..0000000 --- a/baselines/deepq/experiments/run_retro.py +++ /dev/null @@ -1,49 +0,0 @@ -import argparse - -from baselines import deepq -from baselines.common import set_global_seeds -from baselines import bench -from baselines import logger -from baselines.common import retro_wrappers -import retro - - -def main(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--env', help='environment ID', default='SuperMarioBros-Nes') - parser.add_argument('--gamestate', help='game state to load', default='Level1-1') - parser.add_argument('--seed', help='seed', type=int, default=0) - parser.add_argument('--num-timesteps', type=int, default=int(10e6)) - args = parser.parse_args() - logger.configure() - set_global_seeds(args.seed) - env = retro_wrappers.make_retro(game=args.env, state=args.gamestate, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE) - env.seed(args.seed) - env = bench.Monitor(env, logger.get_dir()) - env = retro_wrappers.wrap_deepmind_retro(env) - - model = deepq.models.cnn_to_mlp( - convs=[(32, 8, 4), (64, 4, 2), (64, 
3, 1)], - hiddens=[256], - dueling=True - ) - act = deepq.learn( - env, - q_func=model, - lr=1e-4, - max_timesteps=args.num_timesteps, - buffer_size=10000, - exploration_fraction=0.1, - exploration_final_eps=0.01, - train_freq=4, - learning_starts=10000, - target_network_update_freq=1000, - gamma=0.99, - prioritized_replay=True - ) - act.save() - env.close() - - -if __name__ == '__main__': - main() diff --git a/baselines/deepq/experiments/train_mountaincar.py b/baselines/deepq/experiments/train_mountaincar.py index 061967d..fff678a 100644 --- a/baselines/deepq/experiments/train_mountaincar.py +++ b/baselines/deepq/experiments/train_mountaincar.py @@ -1,17 +1,17 @@ import gym from baselines import deepq +from baselines.common import models def main(): env = gym.make("MountainCar-v0") # Enabling layer_norm here is import for parameter space noise! - model = deepq.models.mlp([64], layer_norm=True) act = deepq.learn( env, - q_func=model, + network=models.mlp(num_hidden=64, num_layers=1), lr=1e-3, - max_timesteps=100000, + total_timesteps=100000, buffer_size=50000, exploration_fraction=0.1, exploration_final_eps=0.1, diff --git a/baselines/deepq/experiments/train_pong.py b/baselines/deepq/experiments/train_pong.py new file mode 100644 index 0000000..a8febb9 --- /dev/null +++ b/baselines/deepq/experiments/train_pong.py @@ -0,0 +1,36 @@ +from baselines import deepq +from baselines import bench +from baselines import logger +from baselines.common.atari_wrappers import make_atari + + +def main(): + logger.configure() + env = make_atari('PongNoFrameskip-v4') + env = bench.Monitor(env, logger.get_dir()) + env = deepq.wrap_atari_dqn(env) + + model = deepq.learn( + env, + "conv_only", + convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], + hiddens=[256], + dueling=True, + lr=1e-4, + total_timesteps=int(1e7), + buffer_size=10000, + exploration_fraction=0.1, + exploration_final_eps=0.01, + train_freq=4, + learning_starts=10000, + target_network_update_freq=1000, + gamma=0.99, + ) + + model.save('pong_model.pkl') + env.close() + + + +if __name__ == '__main__': + main()
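
Note on the API change this patch applies, with a minimal usage sketch. The enjoy scripts no longer call deepq.load(); a saved model is instead restored by calling deepq.learn() with total_timesteps=0 and a load_path, which rebuilds the same network and loads the saved variables without training. The environment id, file name, and training length below are illustrative (taken from the enjoy_cartpole example in this patch), not part of the commit:

    import gym
    from baselines import deepq

    def demo():
        env = gym.make("CartPole-v0")

        # Train briefly and save the TF variables to a single file
        # (ActWrapper.save now calls save_variables under the hood).
        act = deepq.learn(env, network='mlp', total_timesteps=10000)
        act.save("cartpole_model.pkl")

        # Restore: build the same graph, skip training (total_timesteps=0),
        # and load the saved variables via load_path.
        act = deepq.learn(env, network='mlp', total_timesteps=0,
                          load_path="cartpole_model.pkl")

        # Run one greedy episode with the restored policy.
        obs, done = env.reset(), False
        while not done:
            obs, _, done, _ = env.step(act(obs[None])[0])
        env.close()

    if __name__ == '__main__':
        demo()

Because deepq now saves through save_variables/load_variables (a single joblib file of variable values) rather than the deprecated save_state/load_state checkpoint path, the tf_util change above only calls os.makedirs when the target path actually has a directory component, so saving to a bare filename in the current directory works.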